diff --git a/src/lib/orm/error.go b/src/lib/orm/error.go index c4afc288e..af5b2e727 100644 --- a/src/lib/orm/error.go +++ b/src/lib/orm/error.go @@ -54,8 +54,7 @@ func AsNotFoundError(err error, messageFormat string, args ...interface{}) *erro // AsConflictError checks whether the err is duplicate key error. If it it, wrap it // as a src/internal/error.Error with conflict error code, else return nil func AsConflictError(err error, messageFormat string, args ...interface{}) *errors.Error { - var pqErr *pq.Error - if errors.As(err, &pqErr) && pqErr.Code == "23505" { + if isDuplicateKeyError(err) { e := errors.New(err). WithCode(errors.ConflictCode). WithMessage(messageFormat, args...) @@ -67,8 +66,7 @@ func AsConflictError(err error, messageFormat string, args ...interface{}) *erro // AsForeignKeyError checks whether the err is violating foreign key constraint error. If it it, wrap it // as a src/internal/error.Error with violating foreign key constraint error code, else return nil func AsForeignKeyError(err error, messageFormat string, args ...interface{}) *errors.Error { - var pqErr *pq.Error - if errors.As(err, &pqErr) && pqErr.Code == "23503" { + if isViolatingForeignKeyConstraintError(err) { e := errors.New(err). WithCode(errors.ViolateForeignKeyConstraintCode). WithMessage(messageFormat, args...) @@ -76,3 +74,21 @@ func AsForeignKeyError(err error, messageFormat string, args ...interface{}) *er } return nil } + +func isDuplicateKeyError(err error) bool { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr.Code == "23505" { + return true + } + + return false +} + +func isViolatingForeignKeyConstraintError(err error) bool { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr.Code == "23503" { + return true + } + + return false +} diff --git a/src/lib/orm/orm.go b/src/lib/orm/orm.go index 123f5da42..ec4e07255 100644 --- a/src/lib/orm/orm.go +++ b/src/lib/orm/orm.go @@ -17,6 +17,7 @@ package orm import ( "context" "errors" + "fmt" "github.com/astaxie/beego/orm" "github.com/goharbor/harbor/src/lib/log" @@ -87,3 +88,61 @@ func WithTransaction(f func(ctx context.Context) error) func(ctx context.Context return nil } } + +// ReadOrCreate read or create instance to datebase, retry to read when met a duplicate key error after the creating +func ReadOrCreate(ctx context.Context, md interface{}, col1 string, cols ...string) (created bool, id int64, err error) { + getter, ok := md.(interface { + GetID() int64 + }) + + if !ok { + err = fmt.Errorf("missing GetID method for the model %T", md) + return + } + + defer func() { + if !created && err == nil { // found in the database + id = getter.GetID() + } + }() + + o, err := FromContext(ctx) + if err != nil { + return + } + + cols = append([]string{col1}, cols...) + + err = o.Read(md, cols...) + if err == nil { // found in the database + return + } + + if !errors.Is(err, orm.ErrNoRows) { // met a error when read database + return + } + + // not found in the database, try to create one + err = WithTransaction(func(ctx context.Context) error { + o, err := FromContext(ctx) + if err != nil { + return err + } + + id, err = o.Insert(md) + return err + })(ctx) + + if err == nil { // create success + created = true + + return + } + + // got a duplicate key error, try to read again + if isDuplicateKeyError(err) { + err = o.Read(md, cols...) + } + + return +} diff --git a/src/lib/orm/orm_test.go b/src/lib/orm/orm_test.go index b57ffdc85..97745a5ed 100644 --- a/src/lib/orm/orm_test.go +++ b/src/lib/orm/orm_test.go @@ -17,6 +17,7 @@ package orm import ( "context" "errors" + "sync" "testing" "github.com/astaxie/beego/orm" @@ -29,10 +30,14 @@ type Foo struct { Name string `orm:"column(name)"` } -func (*Foo) TableName() string { +func (foo *Foo) TableName() string { return "foo" } +func (foo *Foo) GetID() int64 { + return foo.ID +} + func addFoo(ctx context.Context, foo Foo) (int64, error) { o, err := FromContext(ctx) if err != nil { @@ -349,6 +354,61 @@ func (suite *OrmSuite) TestNestedSavepoint() { suite.False(existFoo(ctx, id2)) } +func (suite *OrmSuite) TestReadOrCreate() { + ctx := NewContext(context.TODO(), orm.NewOrm()) + + var id int64 + f1 := func(ctx context.Context) (err error) { + created1, id1, err := ReadOrCreate(ctx, &Foo{Name: "n1"}, "name") + suite.NoError(err) + suite.True(created1) + + created2, id2, err := ReadOrCreate(ctx, &Foo{Name: "n1"}, "name") + suite.NoError(err) + suite.False(created2) + + suite.Equal(id2, id1) + + id = id1 + + return nil + } + + suite.NoError(WithTransaction(f1)(ctx)) + suite.True(existFoo(ctx, id)) +} + +func (suite *OrmSuite) TestReadOrCreateParallel() { + count := 500 + + arr := make([]int, count) + + var wg sync.WaitGroup + for i := 0; i < count; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + + ctx := NewContext(context.TODO(), orm.NewOrm()) + created, _, err := ReadOrCreate(ctx, &Foo{Name: "n2"}, "name") + suite.NoError(err) + + if created { + arr[i] = 1 + } + }(i) + } + + wg.Wait() + + sum := 0 + for _, v := range arr { + sum += v + } + + suite.Equal(1, sum) +} + func TestRunOrmSuite(t *testing.T) { suite.Run(t, new(OrmSuite)) } diff --git a/src/pkg/scan/dao/scan/model.go b/src/pkg/scan/dao/scan/model.go index 0b483c2ab..1c92938fa 100644 --- a/src/pkg/scan/dao/scan/model.go +++ b/src/pkg/scan/dao/scan/model.go @@ -70,6 +70,23 @@ type VulnerabilityRecord struct { VendorAttributes string `orm:"column(vendor_attributes);type(json);null"` } +// TableName for VulnerabilityRecord +func (vr *VulnerabilityRecord) TableName() string { + return "vulnerability_record" +} + +// TableUnique for VulnerabilityRecord +func (vr *VulnerabilityRecord) TableUnique() [][]string { + return [][]string{ + {"cve_id", "registration_uuid", "package", "package_version"}, + } +} + +// GetID returns the ID of the record +func (vr *VulnerabilityRecord) GetID() int64 { + return vr.ID +} + // ReportVulnerabilityRecord is relation table required to optimize data storage for both the // vulnerability records and the scan report. // identified by composite key (ID, Report) @@ -83,18 +100,6 @@ type ReportVulnerabilityRecord struct { VulnRecordID int64 `orm:"column(vuln_record_id);"` } -// TableName for VulnerabilityRecord -func (vr *VulnerabilityRecord) TableName() string { - return "vulnerability_record" -} - -// TableUnique for VulnerabilityRecord -func (vr *VulnerabilityRecord) TableUnique() [][]string { - return [][]string{ - {"cve_id", "registration_uuid", "package", "package_version"}, - } -} - // TableName for ReportVulnerabilityRecord func (rvr *ReportVulnerabilityRecord) TableName() string { return "report_vulnerability_record" @@ -106,3 +111,8 @@ func (rvr *ReportVulnerabilityRecord) TableUnique() [][]string { {"report_uuid", "vuln_record_id"}, } } + +// GetID returns the ID of the record +func (rvr *ReportVulnerabilityRecord) GetID() int64 { + return rvr.ID +} diff --git a/src/pkg/scan/dao/scan/vulnerability.go b/src/pkg/scan/dao/scan/vulnerability.go index dba0ebfac..264633ed5 100644 --- a/src/pkg/scan/dao/scan/vulnerability.go +++ b/src/pkg/scan/dao/scan/vulnerability.go @@ -17,7 +17,7 @@ package scan import ( "context" "fmt" - "github.com/goharbor/harbor/src/lib/errors" + "github.com/goharbor/harbor/src/lib/orm" "github.com/goharbor/harbor/src/lib/q" ) @@ -60,19 +60,8 @@ type vulnerabilityRecordDao struct{} // Create creates new vulnerability record. func (v *vulnerabilityRecordDao) Create(ctx context.Context, vr *VulnerabilityRecord) (int64, error) { - o, err := orm.FromContext(ctx) - var vrID int64 - err = orm.WithTransaction(func(ctx context.Context) error { - var err error - vrID, err = o.InsertOrUpdate(vr, "cve_id, registration_uuid, package, package_version") - return orm.WrapConflictError(err, "vulnerability already exists") - })(ctx) - if errors.IsConflictErr(err) { - if err := o.Read(vr, "cve_id", "registration_uuid", "package", "package_version"); err != nil { - return 0, err - } - return vr.ID, nil - } + _, vrID, err := orm.ReadOrCreate(ctx, vr, "cve_id", "registration_uuid", "package", "package_version") + return vrID, err } @@ -137,11 +126,7 @@ func (v *vulnerabilityRecordDao) InsertForReport(ctx context.Context, reportUUID rvr.Report = reportUUID rvr.VulnRecordID = vrID - o, err := orm.FromContext(ctx) - if err != nil { - return 0, err - } - _, rvrID, err := o.ReadOrCreate(rvr, "report_uuid", "vuln_record_id") + _, rvrID, err := orm.ReadOrCreate(ctx, rvr, "report_uuid", "vuln_record_id") return rvrID, err @@ -164,8 +149,8 @@ func (v *vulnerabilityRecordDao) GetForReport(ctx context.Context, reportUUID st if err != nil { return nil, err } - query := `select vulnerability_record.* from vulnerability_record - inner join report_vulnerability_record on + query := `select vulnerability_record.* from vulnerability_record + inner join report_vulnerability_record on vulnerability_record.id = report_vulnerability_record.vuln_record_id and report_vulnerability_record.report_uuid=?` _, err = o.Raw(query, reportUUID).QueryRows(&vulnRecs) return vulnRecs, err diff --git a/src/pkg/scan/dao/scan/vulnerability_test.go b/src/pkg/scan/dao/scan/vulnerability_test.go index 22194258e..051d860b9 100644 --- a/src/pkg/scan/dao/scan/vulnerability_test.go +++ b/src/pkg/scan/dao/scan/vulnerability_test.go @@ -1,17 +1,32 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package scan import ( "fmt" + "testing" + "github.com/goharbor/harbor/src/jobservice/job" "github.com/goharbor/harbor/src/lib/orm" "github.com/goharbor/harbor/src/lib/q" "github.com/goharbor/harbor/src/pkg/scan/dao/scanner" - "github.com/goharbor/harbor/src/pkg/scan/rest/v1" + v1 "github.com/goharbor/harbor/src/pkg/scan/rest/v1" htesting "github.com/goharbor/harbor/src/testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "testing" ) const sampleReportWithCompleteVulnData = `{