diff --git a/utils/registry/auth/authorizer.go b/utils/registry/auth/authorizer.go index cea731246..26ea177a5 100644 --- a/utils/registry/auth/authorizer.go +++ b/utils/registry/auth/authorizer.go @@ -46,8 +46,8 @@ func NewRequestAuthorizer(handlers []Handler, challenges []au.Challenge) *Reques // ModifyRequest adds authorization to the request func (r *RequestAuthorizer) ModifyRequest(req *http.Request) error { - for _, handler := range r.handlers { - for _, challenge := range r.challenges { + for _, challenge := range r.challenges { + for _, handler := range r.handlers { if handler.Scheme() == challenge.Scheme { if err := handler.AuthorizeRequest(req, challenge.Parameters); err != nil { return err diff --git a/utils/registry/auth/tokenhandler.go b/utils/registry/auth/tokenhandler.go index f546bac0c..d734ee935 100644 --- a/utils/registry/auth/tokenhandler.go +++ b/utils/registry/auth/tokenhandler.go @@ -168,6 +168,7 @@ func (s *standardTokenHandler) generateToken(realm, service string, scopes []str if resp.StatusCode != http.StatusOK { err = registry_errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } return diff --git a/utils/registry/errors/error.go b/utils/registry/errors/error.go index 60b8d6ce5..7a1311b00 100644 --- a/utils/registry/errors/error.go +++ b/utils/registry/errors/error.go @@ -23,12 +23,13 @@ import ( // an Error instance will be returned type Error struct { StatusCode int + StatusText string Message string } // Error ... func (e Error) Error() string { - return fmt.Sprintf("%d %s", e.StatusCode, e.Message) + return fmt.Sprintf("%d %s %s", e.StatusCode, e.StatusText, e.Message) } // ParseError parses err, if err is type Error, convert it to Error diff --git a/utils/registry/registry.go b/utils/registry/registry.go index 1ee01892e..845c8e876 100644 --- a/utils/registry/registry.go +++ b/utils/registry/registry.go @@ -89,6 +89,10 @@ func (r *Registry) Catalog() ([]string, error) { resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + return repos, e + } return repos, err } @@ -115,6 +119,7 @@ func (r *Registry) Catalog() ([]string, error) { return repos, errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } } diff --git a/utils/registry/repository.go b/utils/registry/repository.go index 507634415..479d44408 100644 --- a/utils/registry/repository.go +++ b/utils/registry/repository.go @@ -72,6 +72,9 @@ func NewRepositoryWithCredential(name, endpoint string, credential auth.Credenti } client, err := newClient(endpoint, "", credential, "repository", name, "pull", "push") + if err != nil { + return nil, err + } repository := &Repository{ Name: name, @@ -108,6 +111,16 @@ func NewRepositoryWithUsername(name, endpoint, username string) (*Repository, er return repository, nil } +func isUnauthorizedError(err error) (error, bool) { + if strings.Contains(err.Error(), http.StatusText(http.StatusUnauthorized)) { + return errors.Error{ + StatusCode: http.StatusUnauthorized, + StatusText: http.StatusText(http.StatusUnauthorized), + }, true + } + return err, false +} + // ListTag ... func (r *Repository) ListTag() ([]string, error) { tags := []string{} @@ -118,6 +131,10 @@ func (r *Repository) ListTag() ([]string, error) { resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + return tags, e + } return tags, err } @@ -141,9 +158,9 @@ func (r *Repository) ListTag() ([]string, error) { return tags, nil } - return tags, errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } @@ -161,6 +178,11 @@ func (r *Repository) ManifestExist(reference string) (digest string, exist bool, resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + err = e + return + } return } @@ -183,6 +205,7 @@ func (r *Repository) ManifestExist(reference string) (digest string, exist bool, err = errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } return @@ -201,6 +224,11 @@ func (r *Repository) PullManifest(reference string, acceptMediaTypes []string) ( resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + err = e + return + } return } @@ -219,6 +247,7 @@ func (r *Repository) PullManifest(reference string, acceptMediaTypes []string) ( err = errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } @@ -236,6 +265,11 @@ func (r *Repository) PushManifest(reference, mediaType string, payload []byte) ( resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + err = e + return + } return } @@ -253,6 +287,7 @@ func (r *Repository) PushManifest(reference, mediaType string, payload []byte) ( err = errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } @@ -268,6 +303,10 @@ func (r *Repository) DeleteManifest(digest string) error { resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + return e + } return err } @@ -284,6 +323,7 @@ func (r *Repository) DeleteManifest(digest string) error { return errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } } @@ -298,6 +338,7 @@ func (r *Repository) DeleteTag(tag string) error { if !exist { return errors.Error{ StatusCode: http.StatusNotFound, + StatusText: http.StatusText(http.StatusNotFound), } } @@ -313,6 +354,10 @@ func (r *Repository) BlobExist(digest string) (bool, error) { resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + return false, e + } return false, err } @@ -333,6 +378,7 @@ func (r *Repository) BlobExist(digest string) (bool, error) { return false, errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } } @@ -346,6 +392,11 @@ func (r *Repository) PullBlob(digest string) (size int64, data []byte, err error resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + err = e + return + } return } @@ -367,6 +418,7 @@ func (r *Repository) PullBlob(digest string) (size int64, data []byte, err error err = errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } @@ -379,6 +431,11 @@ func (r *Repository) initiateBlobUpload(name string) (location, uploadUUID strin resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + err = e + return + } return } @@ -397,6 +454,7 @@ func (r *Repository) initiateBlobUpload(name string) (location, uploadUUID strin err = errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } @@ -411,6 +469,10 @@ func (r *Repository) monolithicBlobUpload(location, digest string, size int64, d resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + return e + } return err } @@ -427,6 +489,7 @@ func (r *Repository) monolithicBlobUpload(location, digest string, size int64, d return errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } } @@ -460,6 +523,10 @@ func (r *Repository) DeleteBlob(digest string) error { resp, err := r.client.Do(req) if err != nil { + e, ok := isUnauthorizedError(err) + if ok { + return e + } return err } @@ -476,6 +543,7 @@ func (r *Repository) DeleteBlob(digest string) error { return errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } } diff --git a/utils/registry/repository_test.go b/utils/registry/repository_test.go new file mode 100644 index 000000000..32332171a --- /dev/null +++ b/utils/registry/repository_test.go @@ -0,0 +1,211 @@ +/* + Copyright (c) 2016 VMware, Inc. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package registry + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/vmware/harbor/utils/log" + "github.com/vmware/harbor/utils/registry/auth" + "github.com/vmware/harbor/utils/registry/errors" +) + +var ( + username string = "user" + password string = "P@ssw0rd" + repo string = "samalba/my-app" + tags tagResp = tagResp{Tags: []string{"1.0", "2.0", "3.0"}} + validToken string = "valid_token" + invalidToken string = "invalid_token" + credential auth.Credential + registryServer *httptest.Server + tokenServer *httptest.Server + repositoryClient *Repository +) + +type tagResp struct { + Tags []string `json:"tags"` +} + +func TestMain(m *testing.M) { + log.SetLevel(log.DebugLevel) + credential = auth.NewBasicAuthCredential(username, password) + + tokenServer = initTokenServer() + defer tokenServer.Close() + + registryServer = initRegistryServer() + defer registryServer.Close() + + os.Exit(m.Run()) +} + +func initRegistryServer() *httptest.Server { + mux := http.NewServeMux() + mux.HandleFunc("/v2/", servePing) + mux.HandleFunc(fmt.Sprintf("/v2/%s/tags/list", repo), serveTaglisting) + + return httptest.NewServer(mux) +} + +//response ping request: http://registry/v2 +func servePing(w http.ResponseWriter, r *http.Request) { + if !isTokenValid(r) { + challenge(w) + return + } +} + +func serveTaglisting(w http.ResponseWriter, r *http.Request) { + if !isTokenValid(r) { + challenge(w) + return + } + + if err := json.NewEncoder(w).Encode(tags); err != nil { + w.Write([]byte(err.Error())) + w.WriteHeader(http.StatusInternalServerError) + return + } + +} + +func isTokenValid(r *http.Request) bool { + valid := false + auth := r.Header.Get(http.CanonicalHeaderKey("Authorization")) + if len(auth) != 0 { + auth = strings.TrimSpace(auth) + index := strings.Index(auth, "Bearer") + token := auth[index+6:] + token = strings.TrimSpace(token) + if token == validToken { + valid = true + } + } + return valid +} + +func challenge(w http.ResponseWriter) { + challenge := "Bearer realm=\"" + tokenServer.URL + "/service/token\",service=\"token-service\"" + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return +} + +func initTokenServer() *httptest.Server { + mux := http.NewServeMux() + mux.HandleFunc("/service/token", serveToken) + + return httptest.NewServer(mux) +} + +func serveToken(w http.ResponseWriter, r *http.Request) { + u, p, ok := r.BasicAuth() + if !ok || u != username || p != password { + w.WriteHeader(http.StatusUnauthorized) + return + } + + result := make(map[string]interface{}) + result["token"] = validToken + result["expires_in"] = 300 + result["issued_at"] = time.Now().Format(time.RFC3339) + + encoder := json.NewEncoder(w) + if err := encoder.Encode(result); err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } +} + +func TestListTag(t *testing.T) { + client, err := NewRepositoryWithCredential(repo, registryServer.URL, credential) + if err != nil { + t.Error(err) + } + + list, err := client.ListTag() + if err != nil { + t.Error(err) + return + } + if len(list) != len(tags.Tags) { + t.Errorf("expected length: %d, actual length: %d", len(tags.Tags), len(list)) + return + } + +} + +func TestListTagWithInvalidUser(t *testing.T) { + credential := auth.NewBasicAuthCredential("user", "test") + client, err := NewRepositoryWithCredential(repo, registryServer.URL, credential) + if err != nil { + t.Error(err) + } + + _, err = client.ListTag() + if err != nil { + e, ok := errors.ParseError(err) + if ok && e.StatusCode == http.StatusUnauthorized { + return + } + t.Error(err) + return + } +} + +/*tokenHandler := func(w http.ResponseWriter, r *http.Request) { + username, _, ok := r.BasicAuth() + if !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + + service := r.FormValue("service") + scopes := r.URL.Query()["scope"] + access := token_util.GetResourceActions(scopes) + + token, _, issuedAt, err := token_util.MakeToken(username, service, access) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + t.Error(err) + return + } + + result := make(map[string]interface{}) + result["token"] = token + result["expires_in"] = "dfsd" + result["issued_at"] = issuedAt.Format(time.RFC3339) + + encoder := json.NewEncoder(w) + if err = encoder.Encode(result); err != nil { + w.WriteHeader(http.StatusInternalServerError) + t.Error(err) + return + } +} + +tokenServer := httptest.NewServer(http.HandlerFunc(tokenHandler)) +defer tokenServer.Close() +*/