diff --git a/src/replication/adapter/awsecr/adapter.go b/src/replication/adapter/awsecr/adapter.go index 0ad569d3f..8a4b42d3e 100644 --- a/src/replication/adapter/awsecr/adapter.go +++ b/src/replication/adapter/awsecr/adapter.go @@ -16,15 +16,10 @@ package awsecr import ( "errors" - "net/http" "regexp" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" awsecrapi "github.com/aws/aws-sdk-go/service/ecr" - commonhttp "github.com/goharbor/harbor/src/common/http" "github.com/goharbor/harbor/src/lib/log" adp "github.com/goharbor/harbor/src/replication/adapter" "github.com/goharbor/harbor/src/replication/adapter/native" @@ -52,11 +47,16 @@ func newAdapter(registry *model.Registry) (*adapter, error) { if err != nil { return nil, err } - authorizer := NewAuth(region, registry.Credential.AccessKey, registry.Credential.AccessSecret, registry.Insecure) + svc, err := getAwsSvc( + region, registry.Credential.AccessKey, registry.Credential.AccessSecret, registry.Insecure, nil) + if err != nil { + return nil, err + } + authorizer := NewAuth(registry.Credential.AccessKey, svc) return &adapter{ registry: registry, Adapter: native.NewAdapterWithAuthorizer(registry, authorizer), - region: region, + cacheSvc: svc, }, nil } @@ -88,9 +88,8 @@ var ( type adapter struct { *native.Adapter - registry *model.Registry - region string - forceEndpoint *string + registry *model.Registry + cacheSvc *awsecrapi.ECR } func (*adapter) Info() (info *model.RegistryInfo, err error) { @@ -205,11 +204,6 @@ func getAdapterInfo() *model.AdapterPattern { // HealthCheck checks health status of a registry func (a *adapter) HealthCheck() (model.HealthStatus, error) { - if a.registry.Credential == nil || - len(a.registry.Credential.AccessKey) == 0 || len(a.registry.Credential.AccessSecret) == 0 { - log.Errorf("no credential to ping registry %s", a.registry.URL) - return model.Unhealthy, nil - } if err := a.Ping(); err != nil { log.Errorf("failed to ping registry %s: %v", a.registry.URL, err) return model.Unhealthy, nil @@ -242,33 +236,7 @@ func (a *adapter) PrepareForPush(resources []*model.Resource) error { } func (a *adapter) createRepository(repository string) error { - if a.registry.Credential == nil || - len(a.registry.Credential.AccessKey) == 0 || len(a.registry.Credential.AccessSecret) == 0 { - return errors.New("no credential ") - } - cred := credentials.NewStaticCredentials( - a.registry.Credential.AccessKey, - a.registry.Credential.AccessSecret, - "") - if a.region == "" { - return errors.New("no region parsed") - } - - config := &aws.Config{ - Credentials: cred, - Region: &a.region, - HTTPClient: &http.Client{ - Transport: commonhttp.GetHTTPTransportByInsecure(a.registry.Insecure), - }, - } - if a.forceEndpoint != nil { - config.Endpoint = a.forceEndpoint - } - sess := session.Must(session.NewSession(config)) - - svc := awsecrapi.New(sess) - - _, err := svc.CreateRepository(&awsecrapi.CreateRepositoryInput{ + _, err := a.cacheSvc.CreateRepository(&awsecrapi.CreateRepositoryInput{ RepositoryName: &repository, }) if err != nil { @@ -284,40 +252,7 @@ func (a *adapter) createRepository(repository string) error { // DeleteManifest ... func (a *adapter) DeleteManifest(repository, reference string) error { - // AWS doesn't implement standard OCI delete manifest API, so use it's sdk. - if a.registry.Credential == nil || - len(a.registry.Credential.AccessKey) == 0 || len(a.registry.Credential.AccessSecret) == 0 { - return errors.New("no credential ") - } - cred := credentials.NewStaticCredentials( - a.registry.Credential.AccessKey, - a.registry.Credential.AccessSecret, - "") - if a.region == "" { - return errors.New("no region parsed") - } - - var tr *http.Transport - if a.registry.Insecure { - tr = commonhttp.GetHTTPTransport(commonhttp.InsecureTransport) - } else { - tr = commonhttp.GetHTTPTransport(commonhttp.SecureTransport) - } - config := &aws.Config{ - Credentials: cred, - Region: &a.region, - HTTPClient: &http.Client{ - Transport: tr, - }, - } - if a.forceEndpoint != nil { - config.Endpoint = a.forceEndpoint - } - sess := session.Must(session.NewSession(config)) - - svc := awsecrapi.New(sess) - - _, err := svc.BatchDeleteImage(&awsecrapi.BatchDeleteImageInput{ + _, err := a.cacheSvc.BatchDeleteImage(&awsecrapi.BatchDeleteImageInput{ RepositoryName: &repository, ImageIds: []*awsecrapi.ImageIdentifier{{ImageTag: &reference}}, }) diff --git a/src/replication/adapter/awsecr/adapter_test.go b/src/replication/adapter/awsecr/adapter_test.go index ce196729b..1eb1d22f9 100644 --- a/src/replication/adapter/awsecr/adapter_test.go +++ b/src/replication/adapter/awsecr/adapter_test.go @@ -3,6 +3,8 @@ package awsecr import ( "errors" "fmt" + awsecrapi "github.com/aws/aws-sdk-go/service/ecr" + "github.com/stretchr/testify/require" "io" "io/ioutil" "net/http" @@ -149,17 +151,23 @@ func getMockAdapter(t *testing.T, hasCred, health bool) (*adapter, *httptest.Ser Type: model.RegistryTypeAwsEcr, URL: server.URL, } + + var svc *awsecrapi.ECR if hasCred { registry.Credential = &model.Credential{ AccessKey: "xxx", AccessSecret: "ppp", } + svc, _ = getAwsSvc( + "test-region", registry.Credential.AccessKey, registry.Credential.AccessSecret, registry.Insecure, &server.URL) + } else { + svc, _ = getAwsSvc( + "test-region", "", "", registry.Insecure, &server.URL) } return &adapter{ - registry: registry, - Adapter: native.NewAdapter(registry), - region: "test-region", - forceEndpoint: &server.URL, + registry: registry, + Adapter: native.NewAdapter(registry), + cacheSvc: svc, }, server } @@ -180,7 +188,7 @@ func TestAdapter_HealthCheck(t *testing.T) { status, err := a.HealthCheck() assert.Nil(t, err) assert.NotNil(t, status) - assert.EqualValues(t, model.Unhealthy, status) + assert.EqualValues(t, model.Healthy, status) a, s = getMockAdapter(t, true, false) defer s.Close() @@ -260,16 +268,18 @@ func TestAwsAuthCredential_Modify(t *testing.T) { }, ) defer server.Close() - a, _ := NewAuth("test-region", "xxx", "ppp", true).(*awsAuthCredential) - a.forceEndpoint = &server.URL + svc, err := getAwsSvc( + "test-region", "xxx", "ppp", true, &server.URL) + require.Nil(t, err) + a, _ := NewAuth("xxx", svc).(*awsAuthCredential) req := httptest.NewRequest(http.MethodGet, "https://1234.dkr.ecr.test-region.amazonaws.com/v2/", nil) - err := a.Modify(req) - assert.Nil(t, err) err = a.Modify(req) - assert.Nil(t, err) + require.Nil(t, err) + err = a.Modify(req) + require.Nil(t, err) time.Sleep(time.Second) err = a.Modify(req) - assert.Nil(t, err) + require.Nil(t, err) } var urlForBenchmark = []string{ diff --git a/src/replication/adapter/awsecr/auth.go b/src/replication/adapter/awsecr/auth.go index 5d77ab122..8729e8f6f 100644 --- a/src/replication/adapter/awsecr/auth.go +++ b/src/replication/adapter/awsecr/auth.go @@ -18,19 +18,19 @@ import ( "encoding/base64" "errors" "fmt" - "net/http" - "net/url" - "strings" - "time" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" "github.com/aws/aws-sdk-go/aws/session" awsecrapi "github.com/aws/aws-sdk-go/service/ecr" commonhttp "github.com/goharbor/harbor/src/common/http" "github.com/goharbor/harbor/src/common/http/modifier" "github.com/goharbor/harbor/src/lib/log" + "net/http" + "net/url" + "strings" + "time" ) // Credential ... @@ -38,11 +38,8 @@ type Credential modifier.Modifier // Implements interface Credential type awsAuthCredential struct { - region string - accessKey string - accessSecret string - insecure bool - forceEndpoint *string + accessKey string + awssvc *awsecrapi.ECR cacheToken *cacheToken cacheExpired *time.Time @@ -91,36 +88,44 @@ func (a *awsAuthCredential) Modify(req *http.Request) error { return nil } -func (a *awsAuthCredential) getAuthorization() (string, string, string, *time.Time, error) { - log.Infof("Aws Ecr getAuthorization %s", a.accessKey) - cred := credentials.NewStaticCredentials( - a.accessKey, - a.accessSecret, - "") - +func getAwsSvc(region, accessKey, accessSecret string, insecure bool, forceEndpoint *string) (*awsecrapi.ECR, error) { + sess, err := session.NewSession() + if err != nil { + return nil, err + } + var cred *credentials.Credentials + log.Debugf("Aws Ecr getAuthorization %s", accessKey) + if accessKey != "" { + cred = credentials.NewStaticCredentials( + accessKey, + accessSecret, + "") + } else { + cred = ec2rolecreds.NewCredentials(sess) + } var tr *http.Transport - if a.insecure { + if insecure { tr = commonhttp.GetHTTPTransport(commonhttp.InsecureTransport) } else { tr = commonhttp.GetHTTPTransport(commonhttp.SecureTransport) } config := &aws.Config{ Credentials: cred, - Region: &a.region, + Region: ®ion, HTTPClient: &http.Client{ Transport: tr, }, } - if a.forceEndpoint != nil { - config.Endpoint = a.forceEndpoint - } - sess, err := session.NewSession(config) - if err != nil { - return "", "", "", nil, err + if forceEndpoint != nil { + config.Endpoint = forceEndpoint } - svc := awsecrapi.New(sess) + svc := awsecrapi.New(sess, config) + return svc, nil +} +func (a *awsAuthCredential) getAuthorization() (string, string, string, *time.Time, error) { + svc := a.awssvc result, err := svc.GetAuthorizationToken(nil) if err != nil { if aerr, ok := err.(awserr.Error); ok { @@ -161,11 +166,9 @@ func (a *awsAuthCredential) isTokenValid() bool { } // NewAuth new aws auth -func NewAuth(region, accessKey, accessSecret string, insecure bool) Credential { +func NewAuth(accessKey string, awssvc *awsecrapi.ECR) Credential { return &awsAuthCredential{ - region: region, - accessKey: accessKey, - accessSecret: accessSecret, - insecure: insecure, + accessKey: accessKey, + awssvc: awssvc, } }