diff --git a/src/core/middlewares/middlewares.go b/src/core/middlewares/middlewares.go index 84bde8107..082c5adaf 100644 --- a/src/core/middlewares/middlewares.go +++ b/src/core/middlewares/middlewares.go @@ -34,6 +34,7 @@ import ( "github.com/goharbor/harbor/src/server/middleware/session" "github.com/goharbor/harbor/src/server/middleware/trace" "github.com/goharbor/harbor/src/server/middleware/transaction" + "github.com/goharbor/harbor/src/server/middleware/url" ) var ( @@ -79,6 +80,7 @@ var ( // MiddleWares returns global middlewares func MiddleWares() []beego.MiddleWare { return []beego.MiddleWare{ + url.Middleware(), mergeslash.Middleware(), trace.Middleware(), metric.Middleware(), diff --git a/src/server/middleware/url/parse.go b/src/server/middleware/url/parse.go new file mode 100644 index 000000000..5d2aaec5e --- /dev/null +++ b/src/server/middleware/url/parse.go @@ -0,0 +1,24 @@ +package url + +import ( + "net/http" + "net/url" + + "github.com/goharbor/harbor/src/lib/errors" + lib_http "github.com/goharbor/harbor/src/lib/http" + "github.com/goharbor/harbor/src/server/middleware" +) + +// Middleware middleware which validates the raw query, especially for the invalid semicolon separator. +func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler { + return middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) { + if r.URL != nil && r.URL.RawQuery != "" { + _, err := url.ParseQuery(r.URL.RawQuery) + if err != nil { + lib_http.SendError(w, errors.New(err).WithCode(errors.BadRequestCode)) + return + } + } + next.ServeHTTP(w, r) + }, skippers...) +} diff --git a/src/server/middleware/url/parse_test.go b/src/server/middleware/url/parse_test.go new file mode 100644 index 000000000..aa37a4d75 --- /dev/null +++ b/src/server/middleware/url/parse_test.go @@ -0,0 +1,36 @@ +package url + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestURL(t *testing.T) { + assert := assert.New(t) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req1 := httptest.NewRequest(http.MethodPost, "/req1?mount=sha256&from=test", nil) + rec1 := httptest.NewRecorder() + Middleware()(next).ServeHTTP(rec1, req1) + assert.Equal(http.StatusOK, rec1.Code) + + req2 := httptest.NewRequest(http.MethodPost, "/req2?mount=sha256&from=test;", nil) + rec2 := httptest.NewRecorder() + Middleware()(next).ServeHTTP(rec2, req2) + assert.Equal(http.StatusBadRequest, rec2.Code) + + req3 := httptest.NewRequest(http.MethodGet, "/req3?foo=bar?", nil) + rec3 := httptest.NewRecorder() + Middleware()(next).ServeHTTP(rec3, req3) + assert.Equal(http.StatusOK, rec3.Code) + + req4 := httptest.NewRequest(http.MethodGet, "/req4", nil) + rec4 := httptest.NewRecorder() + Middleware()(next).ServeHTTP(rec4, req4) + assert.Equal(http.StatusOK, rec4.Code) +}