diff --git a/api/v2.0/swagger.yaml b/api/v2.0/swagger.yaml index 1f06b7192..fa7dbc354 100644 --- a/api/v2.0/swagger.yaml +++ b/api/v2.0/swagger.yaml @@ -6099,7 +6099,8 @@ paths: repository_name(exact match) project_id(exact match) package(exact match) - and tag(exact match) + tag(exact match) + digest(exact match) tags: - securityhub operationId: ListVulnerabilities diff --git a/src/pkg/securityhub/dao/security.go b/src/pkg/securityhub/dao/security.go index 9ccd0fe95..c38839643 100644 --- a/src/pkg/securityhub/dao/security.go +++ b/src/pkg/securityhub/dao/security.go @@ -104,18 +104,24 @@ where a.digest = s.digest ) type filterMetaData struct { - DataType string + // DataType is the data type of the filter, it could be stringType, rangeType + DataType string + // ColumnName is the column name in the database, if it is empty, the key will be used as the column name + ColumnName string + // FilterFunc is the function to generate the filter sql, default is exactMatchFilter FilterFunc func(ctx context.Context, key string, query *q.Query) (sqlStr string, params []interface{}) } +// filterMap define the query condition var filterMap = map[string]*filterMetaData{ - "cve_id": &filterMetaData{DataType: stringType, FilterFunc: exactMatchFilter}, - "severity": &filterMetaData{DataType: stringType, FilterFunc: exactMatchFilter}, + "cve_id": &filterMetaData{DataType: stringType}, + "severity": &filterMetaData{DataType: stringType}, "cvss_score_v3": &filterMetaData{DataType: rangeType, FilterFunc: rangeFilter}, - "project_id": &filterMetaData{DataType: stringType, FilterFunc: exactMatchFilter}, - "repository_name": &filterMetaData{DataType: stringType, FilterFunc: exactMatchFilter}, - "package": &filterMetaData{DataType: stringType, FilterFunc: exactMatchFilter}, + "project_id": &filterMetaData{DataType: stringType}, + "repository_name": &filterMetaData{DataType: stringType}, + "package": &filterMetaData{DataType: stringType}, "tag": &filterMetaData{DataType: stringType, FilterFunc: tagFilter}, + "digest": &filterMetaData{DataType: stringType, ColumnName: "a.digest"}, } var applyFilterFunc func(ctx context.Context, key string, query *q.Query) (sqlStr string, params []interface{}) @@ -125,7 +131,11 @@ func exactMatchFilter(ctx context.Context, key string, query *q.Query) (sqlStr s return } if val, ok := query.Keywords[key]; ok { - sqlStr = fmt.Sprintf(" and %v = ?", key) + col := key + if len(filterMap[key].ColumnName) > 0 { + col = filterMap[key].ColumnName + } + sqlStr = fmt.Sprintf(" and %v = ?", col) params = append(params, val) return } @@ -325,6 +335,9 @@ func applyVulFilter(ctx context.Context, sqlStr string, query *q.Query, params [ queryStr = sqlStr newParam = params for k, m := range filterMap { + if m.FilterFunc == nil { + m.FilterFunc = exactMatchFilter // default filter function is exactMatchFilter + } s, p := m.FilterFunc(ctx, k, query) queryStr = queryStr + s newParam = append(newParam, p...) diff --git a/src/pkg/securityhub/dao/security_test.go b/src/pkg/securityhub/dao/security_test.go index c701699cb..6b792196d 100644 --- a/src/pkg/securityhub/dao/security_test.go +++ b/src/pkg/securityhub/dao/security_test.go @@ -16,6 +16,7 @@ package dao import ( "context" + "strings" "testing" "github.com/stretchr/testify/suite" @@ -60,6 +61,7 @@ values (1003, 1, 'library/hello-world', 'digest1003', 'IMAGE', '2023-06-02 09:16 `insert into scanner_registration (name, url, uuid, auth) values('trivy', 'https://www.vmware.com', 'ruuid', 'empty')`, `insert into vulnerability_record (id, cve_id, registration_uuid, cvss_score_v3) values (1, '2023-4567-12345', 'ruuid', 9.8)`, `insert into report_vulnerability_record (report_uuid, vuln_record_id) VALUES ('uuid', 1)`, + `INSERT INTO tag (repository_id, artifact_id, name) VALUES (1, (select id from artifact where repository_name = 'library/hello-world' limit 1), 'tag_test')`, }) testDao.ExecuteBatchSQL([]string{ @@ -85,6 +87,7 @@ func (suite *SecurityDaoTestSuite) TearDownTest() { `delete from vulnerability_record where cve_id='2023-4567-12345'`, `delete from report_vulnerability_record where report_uuid='ruuid'`, `delete from vulnerability_record where registration_uuid ='uuid2'`, + `delete from tag where name='tag_test'`, }) } @@ -128,13 +131,13 @@ func Test_checkQFilter(t *testing.T) { args args wantErr bool }{ - {"happy_path", args{q.New(q.KeyWords{"sample": 1}), map[string]*filterMetaData{"sample": &filterMetaData{intType, exactMatchFilter}}}, false}, - {"happy_path_cve_id", args{q.New(q.KeyWords{"cve_id": "CVE-2023-2345"}), map[string]*filterMetaData{"cve_id": &filterMetaData{stringType, exactMatchFilter}}}, false}, - {"happy_path_severity", args{q.New(q.KeyWords{"severity": "Critical"}), map[string]*filterMetaData{"severity": &filterMetaData{stringType, exactMatchFilter}}}, false}, - {"happy_path_cvss_score_v3", args{q.New(q.KeyWords{"cvss_score_v3": &q.Range{Min: 2.0, Max: 3.0}}), map[string]*filterMetaData{"cvss_score_v3": &filterMetaData{rangeType, rangeFilter}}}, false}, + {"happy_path", args{q.New(q.KeyWords{"sample": 1}), map[string]*filterMetaData{"sample": &filterMetaData{DataType: intType}}}, false}, + {"happy_path_cve_id", args{q.New(q.KeyWords{"cve_id": "CVE-2023-2345"}), map[string]*filterMetaData{"cve_id": &filterMetaData{DataType: stringType}}}, false}, + {"happy_path_severity", args{q.New(q.KeyWords{"severity": "Critical"}), map[string]*filterMetaData{"severity": &filterMetaData{DataType: stringType}}}, false}, + {"happy_path_cvss_score_v3", args{q.New(q.KeyWords{"cvss_score_v3": &q.Range{Min: 2.0, Max: 3.0}}), map[string]*filterMetaData{"cvss_score_v3": &filterMetaData{DataType: rangeType, FilterFunc: rangeFilter}}}, false}, {"unhappy_path", args{q.New(q.KeyWords{"sample": 1}), map[string]*filterMetaData{"a": &filterMetaData{DataType: intType}}}, true}, - {"unhappy_path2", args{q.New(q.KeyWords{"cve_id": 1}), map[string]*filterMetaData{"cve_id": &filterMetaData{stringType, exactMatchFilter}}}, true}, - {"unhappy_path3", args{q.New(q.KeyWords{"severity": &q.Range{Min: 2.0, Max: 10.0}}), map[string]*filterMetaData{"severity": &filterMetaData{stringType, exactMatchFilter}}}, true}, + {"unhappy_path2", args{q.New(q.KeyWords{"cve_id": 1}), map[string]*filterMetaData{"cve_id": &filterMetaData{DataType: stringType}}}, true}, + {"unhappy_path3", args{q.New(q.KeyWords{"severity": &q.Range{Min: 2.0, Max: 10.0}}), map[string]*filterMetaData{"severity": &filterMetaData{DataType: stringType}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -158,6 +161,7 @@ func (suite *SecurityDaoTestSuite) TestExacthMatchFilter() { wantParams []interface{} }{ {"normal", args{suite.Context(), "cve_id", q.New(q.KeyWords{"cve_id": "CVE-2023-2345"})}, " and cve_id = ?", []interface{}{"CVE-2023-2345"}}, + {"digest", args{suite.Context(), "digest", q.New(q.KeyWords{"digest": "digest123"})}, " and a.digest = ?", []interface{}{"digest123"}}, } for _, tt := range tests { suite.Run(tt.name, func() { @@ -207,3 +211,50 @@ func (suite *SecurityDaoTestSuite) TestListVul() { suite.NoError(err) suite.Equal(1, len(vuls)) } + +func (suite *SecurityDaoTestSuite) TestTagFilter() { + type args struct { + ctx context.Context + key string + query *q.Query + } + tests := []struct { + name string + args args + wantSqlStr string + wantParams []interface{} + }{ + {"normal", args{suite.Context(), "tag", q.New(q.KeyWords{"tag": "tag_test"})}, " and a.id IN", nil}, + } + for _, tt := range tests { + suite.Run(tt.name, func() { + gotSqlStr, gotParams := tagFilter(tt.args.ctx, tt.args.key, tt.args.query) + suite.True(strings.Contains(gotSqlStr, tt.wantSqlStr), "tagFilter() gotSqlStr = %v, want %v", gotSqlStr, tt.wantSqlStr) + suite.Equal(gotParams, tt.wantParams, "tagFilter() gotParams = %v, want %v", gotParams, tt.wantParams) + }) + } +} + +func (suite *SecurityDaoTestSuite) TestApplyVulFilter() { + type args struct { + ctx context.Context + sqlStr string + query *q.Query + params []interface{} + } + tests := []struct { + name string + args args + wantSqlStr string + wantParams []interface{} + }{ + {"normal", args{suite.Context(), "select * from vulnerability_record", q.New(q.KeyWords{"tag": "tag_test"}), nil}, " and a.id IN", nil}, + } + for _, tt := range tests { + suite.Run(tt.name, func() { + gotSqlStr, gotParams := applyVulFilter(tt.args.ctx, tt.args.sqlStr, tt.args.query, tt.args.params) + suite.True(strings.Contains(gotSqlStr, tt.wantSqlStr), "applyVulFilter() gotSqlStr = %v, want %v", gotSqlStr, tt.wantSqlStr) + suite.Equal(gotParams, tt.wantParams, "applyVulFilter() gotParams = %v, want %v", gotParams, tt.wantParams) + }) + } +}