add ut for registry client

This commit is contained in:
Wenkai Yin 2016-05-10 22:01:38 +08:00
parent f6a1b1c4e4
commit 9378d5a345
6 changed files with 290 additions and 4 deletions

View File

@ -46,8 +46,8 @@ func NewRequestAuthorizer(handlers []Handler, challenges []au.Challenge) *Reques
// ModifyRequest adds authorization to the request // ModifyRequest adds authorization to the request
func (r *RequestAuthorizer) ModifyRequest(req *http.Request) error { func (r *RequestAuthorizer) ModifyRequest(req *http.Request) error {
for _, handler := range r.handlers {
for _, challenge := range r.challenges { for _, challenge := range r.challenges {
for _, handler := range r.handlers {
if handler.Scheme() == challenge.Scheme { if handler.Scheme() == challenge.Scheme {
if err := handler.AuthorizeRequest(req, challenge.Parameters); err != nil { if err := handler.AuthorizeRequest(req, challenge.Parameters); err != nil {
return err return err

View File

@ -168,6 +168,7 @@ func (s *standardTokenHandler) generateToken(realm, service string, scopes []str
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
err = registry_errors.Error{ err = registry_errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
return return

View File

@ -23,12 +23,13 @@ import (
// an Error instance will be returned // an Error instance will be returned
type Error struct { type Error struct {
StatusCode int StatusCode int
StatusText string
Message string Message string
} }
// Error ... // Error ...
func (e Error) Error() string { func (e Error) Error() string {
return fmt.Sprintf("%d %s", e.StatusCode, e.Message) return fmt.Sprintf("%d %s %s", e.StatusCode, e.StatusText, e.Message)
} }
// ParseError parses err, if err is type Error, convert it to Error // ParseError parses err, if err is type Error, convert it to Error

View File

@ -89,6 +89,10 @@ func (r *Registry) Catalog() ([]string, error) {
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
return repos, e
}
return repos, err return repos, err
} }
@ -115,6 +119,7 @@ func (r *Registry) Catalog() ([]string, error) {
return repos, errors.Error{ return repos, errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
} }

View File

@ -72,6 +72,9 @@ func NewRepositoryWithCredential(name, endpoint string, credential auth.Credenti
} }
client, err := newClient(endpoint, "", credential, "repository", name, "pull", "push") client, err := newClient(endpoint, "", credential, "repository", name, "pull", "push")
if err != nil {
return nil, err
}
repository := &Repository{ repository := &Repository{
Name: name, Name: name,
@ -108,6 +111,16 @@ func NewRepositoryWithUsername(name, endpoint, username string) (*Repository, er
return repository, nil return repository, nil
} }
func isUnauthorizedError(err error) (error, bool) {
if strings.Contains(err.Error(), http.StatusText(http.StatusUnauthorized)) {
return errors.Error{
StatusCode: http.StatusUnauthorized,
StatusText: http.StatusText(http.StatusUnauthorized),
}, true
}
return err, false
}
// ListTag ... // ListTag ...
func (r *Repository) ListTag() ([]string, error) { func (r *Repository) ListTag() ([]string, error) {
tags := []string{} tags := []string{}
@ -118,6 +131,10 @@ func (r *Repository) ListTag() ([]string, error) {
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
return tags, e
}
return tags, err return tags, err
} }
@ -141,9 +158,9 @@ func (r *Repository) ListTag() ([]string, error) {
return tags, nil return tags, nil
} }
return tags, errors.Error{ return tags, errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
@ -161,6 +178,11 @@ func (r *Repository) ManifestExist(reference string) (digest string, exist bool,
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
err = e
return
}
return return
} }
@ -183,6 +205,7 @@ func (r *Repository) ManifestExist(reference string) (digest string, exist bool,
err = errors.Error{ err = errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
return return
@ -201,6 +224,11 @@ func (r *Repository) PullManifest(reference string, acceptMediaTypes []string) (
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
err = e
return
}
return return
} }
@ -219,6 +247,7 @@ func (r *Repository) PullManifest(reference string, acceptMediaTypes []string) (
err = errors.Error{ err = errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
@ -236,6 +265,11 @@ func (r *Repository) PushManifest(reference, mediaType string, payload []byte) (
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
err = e
return
}
return return
} }
@ -253,6 +287,7 @@ func (r *Repository) PushManifest(reference, mediaType string, payload []byte) (
err = errors.Error{ err = errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
@ -268,6 +303,10 @@ func (r *Repository) DeleteManifest(digest string) error {
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
return e
}
return err return err
} }
@ -284,6 +323,7 @@ func (r *Repository) DeleteManifest(digest string) error {
return errors.Error{ return errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
} }
@ -298,6 +338,7 @@ func (r *Repository) DeleteTag(tag string) error {
if !exist { if !exist {
return errors.Error{ return errors.Error{
StatusCode: http.StatusNotFound, StatusCode: http.StatusNotFound,
StatusText: http.StatusText(http.StatusNotFound),
} }
} }
@ -313,6 +354,10 @@ func (r *Repository) BlobExist(digest string) (bool, error) {
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
return false, e
}
return false, err return false, err
} }
@ -333,6 +378,7 @@ func (r *Repository) BlobExist(digest string) (bool, error) {
return false, errors.Error{ return false, errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
} }
@ -346,6 +392,11 @@ func (r *Repository) PullBlob(digest string) (size int64, data []byte, err error
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
err = e
return
}
return return
} }
@ -367,6 +418,7 @@ func (r *Repository) PullBlob(digest string) (size int64, data []byte, err error
err = errors.Error{ err = errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
@ -379,6 +431,11 @@ func (r *Repository) initiateBlobUpload(name string) (location, uploadUUID strin
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
err = e
return
}
return return
} }
@ -397,6 +454,7 @@ func (r *Repository) initiateBlobUpload(name string) (location, uploadUUID strin
err = errors.Error{ err = errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
@ -411,6 +469,10 @@ func (r *Repository) monolithicBlobUpload(location, digest string, size int64, d
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
return e
}
return err return err
} }
@ -427,6 +489,7 @@ func (r *Repository) monolithicBlobUpload(location, digest string, size int64, d
return errors.Error{ return errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
} }
@ -460,6 +523,10 @@ func (r *Repository) DeleteBlob(digest string) error {
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
e, ok := isUnauthorizedError(err)
if ok {
return e
}
return err return err
} }
@ -476,6 +543,7 @@ func (r *Repository) DeleteBlob(digest string) error {
return errors.Error{ return errors.Error{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
StatusText: resp.Status,
Message: string(b), Message: string(b),
} }
} }

View File

@ -0,0 +1,211 @@
/*
Copyright (c) 2016 VMware, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package registry
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/vmware/harbor/utils/log"
"github.com/vmware/harbor/utils/registry/auth"
"github.com/vmware/harbor/utils/registry/errors"
)
var (
username string = "user"
password string = "P@ssw0rd"
repo string = "samalba/my-app"
tags tagResp = tagResp{Tags: []string{"1.0", "2.0", "3.0"}}
validToken string = "valid_token"
invalidToken string = "invalid_token"
credential auth.Credential
registryServer *httptest.Server
tokenServer *httptest.Server
repositoryClient *Repository
)
type tagResp struct {
Tags []string `json:"tags"`
}
func TestMain(m *testing.M) {
log.SetLevel(log.DebugLevel)
credential = auth.NewBasicAuthCredential(username, password)
tokenServer = initTokenServer()
defer tokenServer.Close()
registryServer = initRegistryServer()
defer registryServer.Close()
os.Exit(m.Run())
}
func initRegistryServer() *httptest.Server {
mux := http.NewServeMux()
mux.HandleFunc("/v2/", servePing)
mux.HandleFunc(fmt.Sprintf("/v2/%s/tags/list", repo), serveTaglisting)
return httptest.NewServer(mux)
}
//response ping request: http://registry/v2
func servePing(w http.ResponseWriter, r *http.Request) {
if !isTokenValid(r) {
challenge(w)
return
}
}
func serveTaglisting(w http.ResponseWriter, r *http.Request) {
if !isTokenValid(r) {
challenge(w)
return
}
if err := json.NewEncoder(w).Encode(tags); err != nil {
w.Write([]byte(err.Error()))
w.WriteHeader(http.StatusInternalServerError)
return
}
}
func isTokenValid(r *http.Request) bool {
valid := false
auth := r.Header.Get(http.CanonicalHeaderKey("Authorization"))
if len(auth) != 0 {
auth = strings.TrimSpace(auth)
index := strings.Index(auth, "Bearer")
token := auth[index+6:]
token = strings.TrimSpace(token)
if token == validToken {
valid = true
}
}
return valid
}
func challenge(w http.ResponseWriter) {
challenge := "Bearer realm=\"" + tokenServer.URL + "/service/token\",service=\"token-service\""
w.Header().Set("Www-Authenticate", challenge)
w.WriteHeader(http.StatusUnauthorized)
return
}
func initTokenServer() *httptest.Server {
mux := http.NewServeMux()
mux.HandleFunc("/service/token", serveToken)
return httptest.NewServer(mux)
}
func serveToken(w http.ResponseWriter, r *http.Request) {
u, p, ok := r.BasicAuth()
if !ok || u != username || p != password {
w.WriteHeader(http.StatusUnauthorized)
return
}
result := make(map[string]interface{})
result["token"] = validToken
result["expires_in"] = 300
result["issued_at"] = time.Now().Format(time.RFC3339)
encoder := json.NewEncoder(w)
if err := encoder.Encode(result); err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
return
}
}
func TestListTag(t *testing.T) {
client, err := NewRepositoryWithCredential(repo, registryServer.URL, credential)
if err != nil {
t.Error(err)
}
list, err := client.ListTag()
if err != nil {
t.Error(err)
return
}
if len(list) != len(tags.Tags) {
t.Errorf("expected length: %d, actual length: %d", len(tags.Tags), len(list))
return
}
}
func TestListTagWithInvalidUser(t *testing.T) {
credential := auth.NewBasicAuthCredential("user", "test")
client, err := NewRepositoryWithCredential(repo, registryServer.URL, credential)
if err != nil {
t.Error(err)
}
_, err = client.ListTag()
if err != nil {
e, ok := errors.ParseError(err)
if ok && e.StatusCode == http.StatusUnauthorized {
return
}
t.Error(err)
return
}
}
/*tokenHandler := func(w http.ResponseWriter, r *http.Request) {
username, _, ok := r.BasicAuth()
if !ok {
w.WriteHeader(http.StatusUnauthorized)
return
}
service := r.FormValue("service")
scopes := r.URL.Query()["scope"]
access := token_util.GetResourceActions(scopes)
token, _, issuedAt, err := token_util.MakeToken(username, service, access)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
t.Error(err)
return
}
result := make(map[string]interface{})
result["token"] = token
result["expires_in"] = "dfsd"
result["issued_at"] = issuedAt.Format(time.RFC3339)
encoder := json.NewEncoder(w)
if err = encoder.Encode(result); err != nil {
w.WriteHeader(http.StatusInternalServerError)
t.Error(err)
return
}
}
tokenServer := httptest.NewServer(http.HandlerFunc(tokenHandler))
defer tokenServer.Close()
*/