diff --git a/src/common/utils/utils.go b/src/common/utils/utils.go index b9f11300e..cea54d342 100644 --- a/src/common/utils/utils.go +++ b/src/common/utils/utils.go @@ -234,3 +234,24 @@ func GetStrValueOfAnyType(value interface{}) string { } return strVal } + +// IsIllegalLength ... +func IsIllegalLength(s string, min int, max int) bool { + if min == -1 { + return (len(s) > max) + } + if max == -1 { + return (len(s) <= min) + } + return (len(s) < min || len(s) > max) +} + +// IsContainIllegalChar ... +func IsContainIllegalChar(s string, illegalChar []string) bool { + for _, c := range illegalChar { + if strings.Index(s, c) >= 0 { + return true + } + } + return false +} diff --git a/src/core/api/project.go b/src/core/api/project.go index 7771b7d77..3903f3861 100644 --- a/src/core/api/project.go +++ b/src/core/api/project.go @@ -508,10 +508,10 @@ func (p *ProjectAPI) Logs() { p.ServeJSON() } -// TODO move this to package models +// TODO move this to pa ckage models func validateProjectReq(req *models.ProjectRequest) error { pn := req.Name - if isIllegalLength(req.Name, projectNameMinLen, projectNameMaxLen) { + if utils.IsIllegalLength(req.Name, projectNameMinLen, projectNameMaxLen) { return fmt.Errorf("Project name is illegal in length. (greater than %d or less than %d)", projectNameMaxLen, projectNameMinLen) } validProjectName := regexp.MustCompile(`^` + restrictedNameChars + `$`) diff --git a/src/core/api/user.go b/src/core/api/user.go index f797096dd..765ed9da6 100644 --- a/src/core/api/user.go +++ b/src/core/api/user.go @@ -16,11 +16,6 @@ package api import ( "fmt" - "net/http" - "regexp" - "strconv" - "strings" - "github.com/goharbor/harbor/src/common" "github.com/goharbor/harbor/src/common/dao" "github.com/goharbor/harbor/src/common/models" @@ -29,6 +24,9 @@ import ( "github.com/goharbor/harbor/src/common/utils" "github.com/goharbor/harbor/src/common/utils/log" "github.com/goharbor/harbor/src/core/config" + "net/http" + "regexp" + "strconv" ) // UserAPI handles request to /api/users/{} @@ -446,13 +444,13 @@ func (ua *UserAPI) modifiable() bool { // validate only validate when user register func validate(user models.User) error { - if isIllegalLength(user.Username, 1, 255) { + if utils.IsIllegalLength(user.Username, 1, 255) { return fmt.Errorf("username with illegal length") } - if isContainIllegalChar(user.Username, []string{",", "~", "#", "$", "%"}) { + if utils.IsContainIllegalChar(user.Username, []string{",", "~", "#", "$", "%"}) { return fmt.Errorf("username contains illegal characters") } - if isIllegalLength(user.Password, 8, 20) { + if utils.IsIllegalLength(user.Password, 8, 20) { return fmt.Errorf("password with illegal length") } return commonValidate(user) @@ -469,35 +467,16 @@ func commonValidate(user models.User) error { return fmt.Errorf("Email can't be empty") } - if isIllegalLength(user.Realname, 1, 255) { + if utils.IsIllegalLength(user.Realname, 1, 255) { return fmt.Errorf("realname with illegal length") } - if isContainIllegalChar(user.Realname, []string{",", "~", "#", "$", "%"}) { + if utils.IsContainIllegalChar(user.Realname, []string{",", "~", "#", "$", "%"}) { return fmt.Errorf("realname contains illegal characters") } - if isIllegalLength(user.Comment, -1, 30) { + if utils.IsIllegalLength(user.Comment, -1, 30) { return fmt.Errorf("comment with illegal length") } return nil } - -func isIllegalLength(s string, min int, max int) bool { - if min == -1 { - return (len(s) > max) - } - if max == -1 { - return (len(s) <= min) - } - return (len(s) < min || len(s) > max) -} - -func isContainIllegalChar(s string, illegalChar []string) bool { - for _, c := range illegalChar { - if strings.Index(s, c) >= 0 { - return true - } - } - return false -} diff --git a/src/core/controllers/oidc.go b/src/core/controllers/oidc.go index 7ba18580d..be1c9ff84 100644 --- a/src/core/controllers/oidc.go +++ b/src/core/controllers/oidc.go @@ -15,13 +15,18 @@ package controllers import ( + "encoding/json" "fmt" "github.com/goharbor/harbor/src/common" + "github.com/goharbor/harbor/src/common/dao" + "github.com/goharbor/harbor/src/common/models" "github.com/goharbor/harbor/src/common/utils" "github.com/goharbor/harbor/src/common/utils/oidc" "github.com/goharbor/harbor/src/core/api" "github.com/goharbor/harbor/src/core/config" + "github.com/pkg/errors" "net/http" + "strings" ) const idTokenKey = "oidc_id_token" @@ -84,7 +89,12 @@ func (oc *OIDCController) Callback() { oc.RenderFormatedError(http.StatusInternalServerError, err) return } - oc.SetSession(idTokenKey, token.IDToken) + ouDataStr, err := json.Marshal(d) + if err != nil { + oc.RenderFormatedError(http.StatusInternalServerError, err) + return + } + oc.SetSession(idTokenKey, string(ouDataStr)) // TODO: check and trigger onboard popup or redirect user to project page oc.Data["json"] = d oc.ServeFormatted() @@ -92,7 +102,49 @@ func (oc *OIDCController) Callback() { // Onboard handles the request to onboard an user authenticated via OIDC provider func (oc *OIDCController) Onboard() { - oc.RenderError(http.StatusNotImplemented, "") - return + username := oc.GetString("username") + if utils.IsIllegalLength(username, 1, 255) { + oc.RenderFormatedError(http.StatusBadRequest, errors.New("username with illegal length")) + return + } + if utils.IsContainIllegalChar(username, []string{",", "~", "#", "$", "%"}) { + oc.RenderFormatedError(http.StatusBadRequest, errors.New("username contains illegal characters")) + return + } + + idTokenStr := oc.GetSession(idTokenKey) + d := &oidcUserData{} + err := json.Unmarshal([]byte(idTokenStr.(string)), &d) + if err != nil { + oc.RenderFormatedError(http.StatusInternalServerError, err) + return + } + oidcUser := models.OIDCUser{ + SubIss: d.Subject + d.Issuer, + // TODO: get secret with secret manager. + Secret: utils.GenerateRandomString(), + } + + var email string + if d.Email == "" { + email = utils.GenerateRandomString() + "@harbor.com" + } + user := models.User{ + Username: username, + Email: email, + OIDCUserMeta: &oidcUser, + } + + err = dao.OnBoardOIDCUser(&user) + if err != nil { + if strings.Contains(err.Error(), dao.ErrDupUser.Error()) { + oc.RenderFormatedError(http.StatusConflict, err) + return + } + oc.RenderFormatedError(http.StatusInternalServerError, err) + return + } + + oc.Controller.Redirect(config.GetPortalURL(), http.StatusMovedPermanently) }