feat(beego): upgrade beego to v1.12 which support middleware (#10524)

1. Upgrade beego to v1.12.0
2. Add RequestID middleware to all HTTP requests.
3. Add Orm middleware to v2 and v2.0 APIs.
4. Remove OrmFilter from all HTTP requests.
5. Fix some test cases which cause panic in API controllers.
6. Enable XSRF for test cases of CommonController.
7. Imporve ReadOnly middleware.

Signed-off-by: He Weiwei <hweiwei@vmware.com>
This commit is contained in:
He Weiwei 2020-01-20 16:41:49 +08:00 committed by GitHub
parent 603cc0f5f3
commit 33dfa1ea11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
115 changed files with 6616 additions and 1127 deletions

View File

@ -226,7 +226,7 @@ func (aj *AJAPI) getLog(id int64) {
// submit submits a job to job service per request
func (aj *AJAPI) submit(ajr *models.AdminJobReq) {
// when the schedule is saved as None without any schedule, just return 200 and do nothing.
if ajr.Schedule.Type == models.ScheduleNone {
if ajr.Schedule == nil || ajr.Schedule.Type == models.ScheduleNone {
return
}

View File

@ -20,7 +20,7 @@ func TestGCPost(t *testing.T) {
t.Error("Error occurred while add a admin job", err.Error())
t.Log(err)
} else {
assert.Equal(200, code, "Add adminjob status should be 200")
assert.Equal(201, code, "Add adminjob status should be 201")
}
}

View File

@ -1,16 +1,33 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"encoding/json"
"fmt"
"github.com/goharbor/harbor/src/pkg/retention/dao"
"github.com/goharbor/harbor/src/pkg/retention/dao/models"
"github.com/goharbor/harbor/src/pkg/retention/policy"
"github.com/goharbor/harbor/src/pkg/retention/policy/rule"
"github.com/stretchr/testify/require"
"net/http"
"testing"
"time"
"github.com/goharbor/harbor/src/pkg/retention/dao"
"github.com/goharbor/harbor/src/pkg/retention/dao/models"
"github.com/goharbor/harbor/src/pkg/retention/mocks"
"github.com/goharbor/harbor/src/pkg/retention/policy"
"github.com/goharbor/harbor/src/pkg/retention/policy/rule"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestGetMetadatas(t *testing.T) {
@ -30,6 +47,16 @@ func TestGetMetadatas(t *testing.T) {
}
func TestCreatePolicy(t *testing.T) {
// mock retention api controller
mockController := &mocks.APIController{}
mockController.On("CreateRetention", mock.AnythingOfType("*policy.Metadata")).Return(int64(1), nil)
controller := retentionController
retentionController = mockController
defer func() {
retentionController = controller
}()
p1 := &policy.Metadata{
Algorithm: "or",
Rules: []rule.Metadata{
@ -87,7 +114,7 @@ func TestCreatePolicy(t *testing.T) {
bodyJSON: p1,
credential: sysAdmin,
},
code: http.StatusOK,
code: http.StatusCreated,
},
{
request: &testingRequest{
@ -267,6 +294,23 @@ func TestPolicy(t *testing.T) {
require.Nil(t, err)
require.True(t, id > 0)
// mock retention api controller
mockController := &mocks.APIController{}
mockController.On("GetRetention", mock.AnythingOfType("int64")).Return(p, nil)
mockController.On("UpdateRetention", mock.AnythingOfType("*policy.Metadata")).Return(nil)
mockController.On("TriggerRetentionExec",
mock.AnythingOfType("int64"),
mock.AnythingOfType("string"),
mock.AnythingOfType("bool")).Return(int64(1), nil)
mockController.On("ListRetentionExecs", mock.AnythingOfType("int64"), mock.AnythingOfType("*q.Query")).Return(nil, nil)
mockController.On("GetTotalOfRetentionExecs", mock.AnythingOfType("int64")).Return(int64(0), nil)
controller := retentionController
retentionController = mockController
defer func() {
retentionController = controller
}()
cases := []*codeCheckingCase{
{
request: &testingRequest{
@ -408,7 +452,7 @@ func TestPolicy(t *testing.T) {
},
credential: sysAdmin,
},
code: http.StatusOK,
code: http.StatusCreated,
},
{
request: &testingRequest{

View File

@ -1,3 +1,17 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
@ -56,7 +70,7 @@ func (suite *ScanAllAPITestSuite) TestScanAllPost() {
// case 1: add a new scan all job
code, err := apiTest.AddScanAll(*admin, adminJob002)
require.NoError(suite.T(), err, "Error occurred while add a scan all job")
suite.Equal(200, code, "Add scan all status should be 200")
suite.Equal(201, code, "Add scan all status should be 200")
}
func (suite *ScanAllAPITestSuite) TestScanAllGet() {

View File

@ -15,23 +15,21 @@ package controllers
import (
"context"
"github.com/goharbor/harbor/src/core/filter"
"fmt"
"net/http"
"net/http/httptest"
// "net/url"
"os"
"path/filepath"
"runtime"
"testing"
"fmt"
"os"
"strings"
"testing"
"github.com/astaxie/beego"
"github.com/goharbor/harbor/src/common"
"github.com/goharbor/harbor/src/common/models"
utilstest "github.com/goharbor/harbor/src/common/utils/test"
"github.com/goharbor/harbor/src/core/config"
"github.com/goharbor/harbor/src/core/filter"
"github.com/goharbor/harbor/src/core/middlewares"
"github.com/stretchr/testify/assert"
)
@ -41,6 +39,7 @@ func init() {
dir := filepath.Dir(file)
dir = filepath.Join(dir, "..")
apppath, _ := filepath.Abs(dir)
beego.BConfig.WebConfig.EnableXSRF = true
beego.BConfig.WebConfig.Session.SessionOn = true
beego.TestBeegoInit(apppath)
beego.AddTemplateExt("htm")
@ -106,9 +105,6 @@ func TestAll(t *testing.T) {
err := middlewares.Init()
assert.Nil(err)
// Has to set to dev so that the xsrf panic can be rendered as 403
beego.BConfig.RunMode = beego.DEV
r, _ := http.NewRequest("POST", "/c/login", nil)
w := httptest.NewRecorder()
beego.BeeApp.Handlers.ServeHTTP(w, r)

View File

@ -1,32 +0,0 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package filter
import (
"github.com/astaxie/beego/context"
o "github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/internal/orm"
)
// OrmFilter set orm.Ormer instance to the context of the http.Request
func OrmFilter(ctx *context.Context) {
if ctx == nil || ctx.Request == nil {
return
}
// This is a temp workaround for beego bug: https://github.com/goharbor/harbor/issues/10446
// After we upgrading beego to the latest version and moving the filter to middleware,
// this workaround can be removed
*(ctx.Request) = *(ctx.Request.WithContext(orm.NewContext(ctx.Request.Context(), o.NewOrm())))
}

View File

@ -17,9 +17,11 @@ package main
import (
"encoding/gob"
"fmt"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
@ -53,6 +55,8 @@ import (
"github.com/goharbor/harbor/src/pkg/types"
"github.com/goharbor/harbor/src/pkg/version"
"github.com/goharbor/harbor/src/replication"
"github.com/goharbor/harbor/src/server/middleware/orm"
"github.com/goharbor/harbor/src/server/middleware/requestid"
)
const (
@ -247,7 +251,6 @@ func main() {
filter.Init()
beego.InsertFilter("/api/*", beego.BeforeStatic, filter.SessionCheck)
beego.InsertFilter("/*", beego.BeforeRouter, filter.OrmFilter)
beego.InsertFilter("/*", beego.BeforeRouter, filter.SecurityFilter)
beego.InsertFilter("/*", beego.BeforeRouter, filter.ReadonlyFilter)
@ -288,6 +291,22 @@ func main() {
}
log.Infof("Version: %s, Git commit: %s", version.ReleaseVersion, version.GitCommit)
beego.Run()
middlewares := []beego.MiddleWare{
requestid.Middleware(),
orm.Middleware(legacyAPISkipper),
}
beego.RunWithMiddleWares("", middlewares...)
}
// legacyAPISkipper skip middleware for legacy APIs
func legacyAPISkipper(r *http.Request) bool {
for _, prefix := range []string{"/v2/", "/api/v2.0/"} {
if strings.HasPrefix(r.URL.Path, prefix) {
return false
}
}
return true
}

View File

@ -5,14 +5,13 @@ go 1.12
replace github.com/goharbor/harbor => ../
require (
github.com/Knetic/govaluate v3.0.0+incompatible // indirect
github.com/Masterminds/semver v1.4.2
github.com/Microsoft/go-winio v0.4.12 // indirect
github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d // indirect
github.com/Unknwon/goconfig v0.0.0-20160216183935-5f601ca6ef4d // indirect
github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 // indirect
github.com/aliyun/alibaba-cloud-sdk-go v0.0.0-20190726115642-cd293c93fd97
github.com/astaxie/beego v1.9.0
github.com/astaxie/beego v1.12.0
github.com/aws/aws-sdk-go v1.19.47
github.com/beego/i18n v0.0.0-20140604031826-e87155e8f0c0
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect
@ -71,6 +70,7 @@ require (
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/prometheus/client_golang v0.9.4 // indirect
github.com/robfig/cron v1.0.0
github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 // indirect
github.com/sirupsen/logrus v1.4.1 // indirect
github.com/spf13/viper v1.4.0 // indirect
github.com/stretchr/testify v1.4.0

View File

@ -11,6 +11,7 @@ github.com/Masterminds/semver v1.4.2/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF0
github.com/Microsoft/go-winio v0.4.12 h1:xAfWHN1IrQ0NJ9TBC0KBZoqLjzDTr1ML+4MywiUOryc=
github.com/Microsoft/go-winio v0.4.12/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/OwnLocal/goes v1.0.0/go.mod h1:8rIFjBGTue3lCU0wplczcUgt9Gxgrkkrw7etMIcn8TM=
github.com/PuerkitoBio/purell v1.1.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI=
github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
@ -34,13 +35,15 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5
github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4B6AGu/h5Sxe66HYVdqdGu2l9Iebqhi/AEoA=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/astaxie/beego v1.9.0 h1:tPzS+D1oCLi+SEb/TLNRNYpCjaMVfAGoy9OTLwS5ul4=
github.com/astaxie/beego v1.9.0/go.mod h1:0R4++1tUqERR0WYFWdfkcrsyoVBCG4DgpDGokT3yb+U=
github.com/astaxie/beego v1.12.0 h1:MRhVoeeye5N+Flul5PoVfD9CslfdoH+xqC/xvSQ5u2Y=
github.com/astaxie/beego v1.12.0/go.mod h1:fysx+LZNZKnvh4GED/xND7jWtjCR6HzydR2Hh2Im57o=
github.com/aws/aws-sdk-go v1.19.47 h1:ZEze0mpk8Fttrsz6UNLqhH/jRGYbMPfWFA2ILas4AmM=
github.com/aws/aws-sdk-go v1.19.47/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f/go.mod h1:AuiFmCCPBSrqvVMvuqFuk0qogytodnVFVSN5CeJB8Gc=
github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd/go.mod h1:1b+Y/CofkYwXMUU0OhQqGvsY2Bvgr4j6jfT699wyZKQ=
github.com/beego/i18n v0.0.0-20140604031826-e87155e8f0c0 h1:fQaDnUQvBXHHQdGBu9hz8nPznB4BeiPQokvmQVjmNEw=
github.com/beego/i18n v0.0.0-20140604031826-e87155e8f0c0/go.mod h1:KLeFCpAMq2+50NkXC8iiJxLLiiTfTqrGtKEVm+2fk7s=
github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542/go.mod h1:kSeGC/p1AbBiEp5kat81+DSQrZenVBZXklMLaELspWU=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
@ -52,6 +55,7 @@ github.com/bmatcuk/doublestar v1.1.1 h1:YroD6BJCZBYx06yYFEWvUuKVWQn3vLLQAVmDmvTS
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60=
github.com/bugsnag/bugsnag-go v1.5.2 h1:fdaGJJEReigPzSE6HajOhpJwE2IEP/TdHDHXKGeOJtc=
github.com/bugsnag/bugsnag-go v1.5.2/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8=
github.com/bugsnag/panicwrap v1.2.0 h1:OzrKrRvXis8qEvOkfcxNcYbOd2O7xXS2nnKMEMABFQA=
@ -64,6 +68,7 @@ github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghf
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudflare/cfssl v0.0.0-20190510060611-9c027c93ba9e h1:ZtyhUG4s94BMUCdgvRZySr/AXYL5CDcjxhIV/83xJog=
github.com/cloudflare/cfssl v0.0.0-20190510060611-9c027c93ba9e/go.mod h1:yMWuSON2oQp+43nFtAV/uvKQIFpSPerB57DCt9t8sSA=
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80=
github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-oidc v2.0.0+incompatible h1:+RStIopZ8wooMx+Vs5Bt8zMXxV1ABl5LbakNExNmZIg=
@ -71,6 +76,10 @@ github.com/coreos/go-oidc v2.0.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHo
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb/go.mod h1:TWI8EKQMs5u5jLKW/tsb9VwauIrMIxQG1r5fMsswK5U=
github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c/go.mod h1:srVSlQLB8iXBVXHgnqemxUXqN6FCvClgCMPCsjBDR7c=
github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a/go.mod h1:BQwMFlJzDjFDG3DJUdU0KORxn88UlsOULuxLExMh3Hs=
github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76/go.mod h1:vYwsqCOLxGiisLwp9rITslkFNpZD5rz43tf41QFkTWY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -99,6 +108,9 @@ github.com/docker/libtrust v0.0.0-20160708172513-aabc10ec26b7/go.mod h1:cyGadeNE
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M=
github.com/elazarl/go-bindata-assetfs v1.0.0 h1:G/bYguwHIzWq9ZoyUQqrjTmJbbYn3j3CKKpKinvZLFk=
github.com/elazarl/go-bindata-assetfs v1.0.0/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4=
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y=
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
@ -163,6 +175,7 @@ github.com/go-openapi/validate v0.18.0/go.mod h1:Uh4HdOzKt19xGIGm1qHf/ofbX1YQ4Y+
github.com/go-openapi/validate v0.19.2/go.mod h1:1tRCw7m3jtI8eNWEEliiAqUIcBztB2KDnRCRMUi7GTA=
github.com/go-openapi/validate v0.19.3 h1:PAH/2DylwWcIU1s0Y7k3yNmeAgWOcKrNE2Q7Ww/kCg4=
github.com/go-openapi/validate v0.19.3/go.mod h1:90Vh6jjkTn+OT1Eefm0ZixWNFjhtOH7vS9k0lo6zwJo=
github.com/go-redis/redis v6.14.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
@ -262,6 +275,7 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4=
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/magiconair/properties v1.8.0 h1:LLgXmsheXeRoUOBOjtwPQCWIYqM/LU1ayDtDePerRcY=
@ -336,6 +350,11 @@ github.com/robfig/cron v1.0.0 h1:slmQxIUH6U9ruw4XoJ7C2pyyx4yYeiHx8S9pNootHsM=
github.com/robfig/cron v1.0.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 h1:X+yvsM2yrEktyI+b2qND5gpH8YhURn0k8OCaeRnkINo=
github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644/go.mod h1:nkxAfR/5quYxwPZhyDxgasBMnRtBZd0FCEpawpjMUFg=
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726/go.mod h1:3yhqj7WBBfRhbBlzyOC3gUxftwsU0u8gqevxwIHQpMw=
github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373/go.mod h1:mF1DpOSOUiJRMR+FDqaqu3EBqrybQtrDDszLUZ6oxPg=
github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d/go.mod h1:AMEsy7v5z92TR1JKMkLLoaOQk++LVnOKL3ScbJ8GNGA=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
@ -355,6 +374,7 @@ github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/viper v1.4.0 h1:yXHLWeravcrgGyFSyCgdYpXQ9dR9c/WED3pg1RhxqEU=
github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE=
github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec/go.mod h1:QBvMkMya+gXctz3kmljlUCu/yB3GZ6oee+dUozsezQE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -365,12 +385,14 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c/go.mod h1:Z4AUp2Km+PwemOoO/VB5AOx9XSsIItzFjoJlOSiYmn0=
github.com/theupdateframework/notary v0.6.1 h1:7wshjstgS9x9F5LuB1L5mBI2xNMObWqjz+cjWoom6l0=
github.com/theupdateframework/notary v0.6.1/go.mod h1:MOfgIfmox8s7/7fduvB2xyPPMJCrjRLRizA8OFwpnKY=
github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
@ -382,6 +404,7 @@ go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=

View File

@ -25,6 +25,8 @@ import (
"github.com/goharbor/harbor/src/pkg/scheduler"
)
// go:generate mockery -name APIController -case snake
// APIController to handle the requests related with retention
type APIController interface {
GetRetention(id int64) (*policy.Metadata, error)

View File

@ -0,0 +1,257 @@
// Code generated by mockery v1.0.0. DO NOT EDIT.
package mocks
import (
policy "github.com/goharbor/harbor/src/pkg/retention/policy"
q "github.com/goharbor/harbor/src/pkg/retention/q"
mock "github.com/stretchr/testify/mock"
retention "github.com/goharbor/harbor/src/pkg/retention"
)
// APIController is an autogenerated mock type for the APIController type
type APIController struct {
mock.Mock
}
// CreateRetention provides a mock function with given fields: p
func (_m *APIController) CreateRetention(p *policy.Metadata) (int64, error) {
ret := _m.Called(p)
var r0 int64
if rf, ok := ret.Get(0).(func(*policy.Metadata) int64); ok {
r0 = rf(p)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(*policy.Metadata) error); ok {
r1 = rf(p)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// DeleteRetention provides a mock function with given fields: id
func (_m *APIController) DeleteRetention(id int64) error {
ret := _m.Called(id)
var r0 error
if rf, ok := ret.Get(0).(func(int64) error); ok {
r0 = rf(id)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetRetention provides a mock function with given fields: id
func (_m *APIController) GetRetention(id int64) (*policy.Metadata, error) {
ret := _m.Called(id)
var r0 *policy.Metadata
if rf, ok := ret.Get(0).(func(int64) *policy.Metadata); ok {
r0 = rf(id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*policy.Metadata)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int64) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetRetentionExec provides a mock function with given fields: eid
func (_m *APIController) GetRetentionExec(eid int64) (*retention.Execution, error) {
ret := _m.Called(eid)
var r0 *retention.Execution
if rf, ok := ret.Get(0).(func(int64) *retention.Execution); ok {
r0 = rf(eid)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*retention.Execution)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int64) error); ok {
r1 = rf(eid)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetRetentionExecTaskLog provides a mock function with given fields: taskID
func (_m *APIController) GetRetentionExecTaskLog(taskID int64) ([]byte, error) {
ret := _m.Called(taskID)
var r0 []byte
if rf, ok := ret.Get(0).(func(int64) []byte); ok {
r0 = rf(taskID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int64) error); ok {
r1 = rf(taskID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetTotalOfRetentionExecTasks provides a mock function with given fields: executionID
func (_m *APIController) GetTotalOfRetentionExecTasks(executionID int64) (int64, error) {
ret := _m.Called(executionID)
var r0 int64
if rf, ok := ret.Get(0).(func(int64) int64); ok {
r0 = rf(executionID)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(int64) error); ok {
r1 = rf(executionID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetTotalOfRetentionExecs provides a mock function with given fields: policyID
func (_m *APIController) GetTotalOfRetentionExecs(policyID int64) (int64, error) {
ret := _m.Called(policyID)
var r0 int64
if rf, ok := ret.Get(0).(func(int64) int64); ok {
r0 = rf(policyID)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(int64) error); ok {
r1 = rf(policyID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListRetentionExecTasks provides a mock function with given fields: executionID, query
func (_m *APIController) ListRetentionExecTasks(executionID int64, query *q.Query) ([]*retention.Task, error) {
ret := _m.Called(executionID, query)
var r0 []*retention.Task
if rf, ok := ret.Get(0).(func(int64, *q.Query) []*retention.Task); ok {
r0 = rf(executionID, query)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*retention.Task)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int64, *q.Query) error); ok {
r1 = rf(executionID, query)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListRetentionExecs provides a mock function with given fields: policyID, query
func (_m *APIController) ListRetentionExecs(policyID int64, query *q.Query) ([]*retention.Execution, error) {
ret := _m.Called(policyID, query)
var r0 []*retention.Execution
if rf, ok := ret.Get(0).(func(int64, *q.Query) []*retention.Execution); ok {
r0 = rf(policyID, query)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*retention.Execution)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int64, *q.Query) error); ok {
r1 = rf(policyID, query)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// OperateRetentionExec provides a mock function with given fields: eid, action
func (_m *APIController) OperateRetentionExec(eid int64, action string) error {
ret := _m.Called(eid, action)
var r0 error
if rf, ok := ret.Get(0).(func(int64, string) error); ok {
r0 = rf(eid, action)
} else {
r0 = ret.Error(0)
}
return r0
}
// TriggerRetentionExec provides a mock function with given fields: policyID, trigger, dryRun
func (_m *APIController) TriggerRetentionExec(policyID int64, trigger string, dryRun bool) (int64, error) {
ret := _m.Called(policyID, trigger, dryRun)
var r0 int64
if rf, ok := ret.Get(0).(func(int64, string, bool) int64); ok {
r0 = rf(policyID, trigger, dryRun)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(int64, string, bool) error); ok {
r1 = rf(policyID, trigger, dryRun)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateRetention provides a mock function with given fields: p
func (_m *APIController) UpdateRetention(p *policy.Metadata) error {
ret := _m.Called(p)
var r0 error
if rf, ok := ret.Get(0).(func(*policy.Metadata) error); ok {
r0 = rf(p)
} else {
r0 = ret.Error(0)
}
return r0
}

View File

@ -29,3 +29,23 @@ func WithMiddlewares(handler http.Handler, middlewares ...Middleware) http.Handl
}
return handler
}
// Skipper defines a function to skip middleware.
// Returning true skips processing the middleware.
type Skipper func(*http.Request) bool
// New make a middleware from fn which type is func(w http.ResponseWriter, r *http.Request, next http.Handler)
func New(fn func(http.ResponseWriter, *http.Request, http.Handler), skippers ...Skipper) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, skipper := range skippers {
if skipper(r) {
next.ServeHTTP(w, r)
return
}
}
fn(w, r, next)
})
}
}

View File

@ -15,10 +15,11 @@
package middleware
import (
"github.com/stretchr/testify/suite"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
)
type middlewareTestSuite struct {
@ -46,6 +47,32 @@ func (m *middlewareTestSuite) TestWithMiddlewares() {
m.Equal("middleware1middleware2handler", record.Header().Get("key"))
}
func (m *middlewareTestSuite) TestNew() {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
f := func(w http.ResponseWriter, r *http.Request, next http.Handler) {
w.Header().Set("key", "value")
next.ServeHTTP(w, r)
}
req1 := httptest.NewRequest(http.MethodGet, "/req", nil)
rec1 := httptest.NewRecorder()
New(f)(next).ServeHTTP(rec1, req1)
m.Equal("value", rec1.Header().Get("key"))
req2 := httptest.NewRequest(http.MethodGet, "/req", nil)
rec2 := httptest.NewRecorder()
New(f, func(r *http.Request) bool { return r.URL.Path == "/req" })(next).ServeHTTP(rec2, req2)
m.Equal("", rec2.Header().Get("key"))
req3 := httptest.NewRequest(http.MethodGet, "/req3", nil)
rec3 := httptest.NewRecorder()
New(f, func(r *http.Request) bool { return r.URL.Path == "/req" })(next).ServeHTTP(rec3, req3)
m.Equal("value", rec3.Header().Get("key"))
}
func TestMiddlewareTestSuite(t *testing.T) {
suite.Run(t, &middlewareTestSuite{})
}

View File

@ -0,0 +1,55 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"net/http"
o "github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/internal/orm"
"github.com/goharbor/harbor/src/server/middleware"
)
// Config defines the config for Orm middleware.
type Config struct {
// Creator defines a function to create ormer
Creator func() o.Ormer
}
var (
// DefaultConfig default orm config
DefaultConfig = Config{
Creator: func() o.Ormer {
return o.NewOrm()
},
}
)
// Middleware middleware which add ormer to the http request context with default config
func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler {
return MiddlewareWithConfig(DefaultConfig, skippers...)
}
// MiddlewareWithConfig middleware which add ormer to the http request context with config
func MiddlewareWithConfig(config Config, skippers ...middleware.Skipper) func(http.Handler) http.Handler {
if config.Creator == nil {
config.Creator = DefaultConfig.Creator
}
return middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
ctx := orm.NewContext(r.Context(), config.Creator())
next.ServeHTTP(w, r.WithContext(ctx))
}, skippers...)
}

View File

@ -0,0 +1,53 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"net/http"
"net/http/httptest"
"testing"
o "github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/internal/orm"
"github.com/stretchr/testify/assert"
)
type mockOrmer struct {
o.Ormer
}
func TestOrm(t *testing.T) {
assert := assert.New(t)
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := orm.FromContext(r.Context())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
} else {
w.WriteHeader(http.StatusOK)
}
})
req1 := httptest.NewRequest(http.MethodGet, "/req1", nil)
rec1 := httptest.NewRecorder()
next.ServeHTTP(rec1, req1)
assert.Equal(http.StatusInternalServerError, rec1.Code)
req2 := httptest.NewRequest(http.MethodGet, "/req2", nil)
rec2 := httptest.NewRecorder()
MiddlewareWithConfig(Config{Creator: func() o.Ormer { return &mockOrmer{} }})(next).ServeHTTP(rec2, req2)
assert.Equal(http.StatusOK, rec2.Code)
}

View File

@ -1,27 +0,0 @@
package middleware
import (
"github.com/goharbor/harbor/src/common/utils/log"
"github.com/goharbor/harbor/src/core/config"
internal_errors "github.com/goharbor/harbor/src/internal/error"
"net/http"
)
type readonlyHandler struct {
next http.Handler
}
// ReadOnly middleware reject request when harbor set to readonly
func ReadOnly() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if config.ReadOnly() {
log.Warningf("The request is prohibited in readonly mode, url is: %s", req.URL.Path)
pkgE := internal_errors.New(nil).WithCode(internal_errors.DENIED).WithMessage("The system is in read only mode. Any modification is prohibited.")
http.Error(rw, internal_errors.NewErrs(pkgE).Error(), http.StatusForbidden)
return
}
next.ServeHTTP(rw, req)
})
}
}

View File

@ -0,0 +1,78 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package readonly
import (
"net/http"
"github.com/goharbor/harbor/src/common/utils/log"
"github.com/goharbor/harbor/src/core/config"
internal_errors "github.com/goharbor/harbor/src/internal/error"
"github.com/goharbor/harbor/src/server/middleware"
)
// Config defines the config for ReadOnly middleware.
type Config struct {
// ReadOnly defines a function to check whether is readonly mode for request
ReadOnly func(*http.Request) bool
}
var (
// DefaultConfig default readonly config
DefaultConfig = Config{
ReadOnly: func(r *http.Request) bool {
return config.ReadOnly()
},
}
// See more for safe method at https://developer.mozilla.org/en-US/docs/Glossary/safe
safeMethods = map[string]bool{
http.MethodGet: true,
http.MethodHead: true,
http.MethodOptions: true,
}
)
// safeMethodSkipper returns false when the request method is safe methods
func safeMethodSkipper(r *http.Request) bool {
return safeMethods[r.Method]
}
// Middleware middleware reject request when harbor set to readonly with default config
func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler {
return MiddlewareWithConfig(DefaultConfig, skippers...)
}
// MiddlewareWithConfig middleware reject request when harbor set to readonly with config
func MiddlewareWithConfig(config Config, skippers ...middleware.Skipper) func(http.Handler) http.Handler {
if len(skippers) == 0 {
skippers = []middleware.Skipper{safeMethodSkipper}
}
if config.ReadOnly == nil {
config.ReadOnly = DefaultConfig.ReadOnly
}
return middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
if config.ReadOnly(r) {
log.Warningf("The request is prohibited in readonly mode, url is: %s", r.URL.Path)
pkgE := internal_errors.New(nil).WithCode(internal_errors.DENIED).WithMessage("The system is in read only mode. Any modification is prohibited.")
http.Error(w, internal_errors.NewErrs(pkgE).Error(), http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
}, skippers...)
}

View File

@ -0,0 +1,55 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package readonly
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestReadOnly(t *testing.T) {
assert := assert.New(t)
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
readOnly := func(readOnly bool) Config {
return Config{
ReadOnly: func(*http.Request) bool {
return readOnly
},
}
}
req1 := httptest.NewRequest(http.MethodDelete, "/resource", nil)
rec1 := httptest.NewRecorder()
MiddlewareWithConfig(readOnly(true))(next).ServeHTTP(rec1, req1)
assert.Equal(http.StatusForbidden, rec1.Code)
req2 := httptest.NewRequest(http.MethodDelete, "/resource", nil)
rec2 := httptest.NewRecorder()
MiddlewareWithConfig(readOnly(false))(next).ServeHTTP(rec2, req2)
assert.Equal(http.StatusOK, rec2.Code)
// safe method
req3 := httptest.NewRequest(http.MethodGet, "/resource", nil)
rec3 := httptest.NewRecorder()
MiddlewareWithConfig(readOnly(true))(next).ServeHTTP(rec3, req3)
assert.Equal(http.StatusOK, rec3.Code)
}

View File

@ -1,50 +0,0 @@
package middleware
import (
"github.com/goharbor/harbor/src/common"
config2 "github.com/goharbor/harbor/src/common/config"
"github.com/goharbor/harbor/src/core/config"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMain(m *testing.M) {
conf := map[string]interface{}{
common.ReadOnly: "true",
}
kp := &config2.PresetKeyProvider{Key: "naa4JtarA1Zsc3uY"}
config.InitWithSettings(conf, kp)
result := m.Run()
if result != 0 {
os.Exit(result)
}
}
func TestReadOnly(t *testing.T) {
assert := assert.New(t)
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
})
// delete
req := httptest.NewRequest(http.MethodDelete, "/readonly1", nil)
rec := httptest.NewRecorder()
ReadOnly()(next).ServeHTTP(rec, req)
assert.Equal(rec.Code, http.StatusForbidden)
update := map[string]interface{}{
common.ReadOnly: "false",
}
config.GetCfgManager().UpdateConfig(update)
req2 := httptest.NewRequest(http.MethodDelete, "/readonly2", nil)
rec2 := httptest.NewRecorder()
ReadOnly()(next).ServeHTTP(rec2, req2)
assert.Equal(rec2.Code, http.StatusOK)
}

View File

@ -12,29 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package middleware
package requestid
import (
"net/http"
"github.com/goharbor/harbor/src/server/middleware"
"github.com/google/uuid"
)
// HeaderXRequestID X-Request-ID header
const HeaderXRequestID = "X-Request-ID"
// RequestID middleware which add X-Request-ID header in the http request when not exist
func RequestID() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rid := r.Header.Get(HeaderXRequestID)
if rid == "" {
rid = uuid.New().String()
r.Header.Set(HeaderXRequestID, rid)
}
// Middleware middleware which add X-Request-ID header in the http request when not exist
func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler {
return middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
rid := r.Header.Get(HeaderXRequestID)
if rid == "" {
rid = uuid.New().String()
r.Header.Set(HeaderXRequestID, rid)
}
w.Header().Set(HeaderXRequestID, rid)
next.ServeHTTP(w, r)
})
}
w.Header().Set(HeaderXRequestID, rid)
next.ServeHTTP(w, r)
}, skippers...)
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package middleware
package requestid
import (
"net/http"
@ -26,22 +26,22 @@ func TestRequestID(t *testing.T) {
assert := assert.New(t)
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.WriteHeader(http.StatusOK)
})
req1 := httptest.NewRequest(http.MethodGet, "/req1", nil)
rr1 := httptest.NewRecorder()
next.ServeHTTP(rr1, req1)
assert.Equal(rr1.Header().Get(HeaderXRequestID), "")
rec1 := httptest.NewRecorder()
next.ServeHTTP(rec1, req1)
assert.Equal("", rec1.Header().Get(HeaderXRequestID))
req2 := httptest.NewRequest(http.MethodGet, "/req2", nil)
rr2 := httptest.NewRecorder()
RequestID()(next).ServeHTTP(rr2, req2)
assert.NotEqual(rr2.Header().Get(HeaderXRequestID), "")
rec2 := httptest.NewRecorder()
Middleware()(next).ServeHTTP(rec2, req2)
assert.NotEqual("", rec2.Header().Get(HeaderXRequestID))
req3 := httptest.NewRequest(http.MethodGet, "/req2", nil)
req3 := httptest.NewRequest(http.MethodGet, "/req3", nil)
req3.Header.Add(HeaderXRequestID, "852803be-e5fe-499b-bbea-c9e5b5f43916")
rr3 := httptest.NewRecorder()
RequestID()(next).ServeHTTP(rr3, req3)
assert.Equal(rr3.Header().Get(HeaderXRequestID), "852803be-e5fe-499b-bbea-c9e5b5f43916")
rec3 := httptest.NewRecorder()
Middleware()(next).ServeHTTP(rec3, req3)
assert.Equal("852803be-e5fe-499b-bbea-c9e5b5f43916", rec3.Header().Get(HeaderXRequestID))
}

View File

@ -15,17 +15,18 @@
package registry
import (
"net/http"
"net/http/httputil"
"net/url"
"github.com/goharbor/harbor/src/pkg/project"
pkg_repo "github.com/goharbor/harbor/src/pkg/repository"
pkg_tag "github.com/goharbor/harbor/src/pkg/tag"
"github.com/goharbor/harbor/src/server/middleware"
"github.com/goharbor/harbor/src/server/middleware/manifestinfo"
"github.com/goharbor/harbor/src/server/middleware/readonly"
"github.com/goharbor/harbor/src/server/middleware/regtoken"
"github.com/goharbor/harbor/src/server/registry/blob"
"net/http"
"net/http/httputil"
"net/url"
"github.com/goharbor/harbor/src/server/registry/catalog"
"github.com/goharbor/harbor/src/server/registry/manifest"
"github.com/goharbor/harbor/src/server/registry/tag"
@ -53,18 +54,18 @@ func New(url *url.URL) http.Handler {
manifestRouter := rootRouter.Path("/v2/{name:.*}/manifests/{reference}").Subrouter()
manifestRouter.NewRoute().Methods(http.MethodGet).Handler(middleware.WithMiddlewares(manifest.NewHandler(project.Mgr, proxy), manifestinfo.Middleware(), regtoken.Middleware()))
manifestRouter.NewRoute().Methods(http.MethodHead).Handler(manifest.NewHandler(project.Mgr, proxy))
manifestRouter.NewRoute().Methods(http.MethodPut).Handler(middleware.WithMiddlewares(manifest.NewHandler(project.Mgr, proxy), middleware.ReadOnly()))
manifestRouter.NewRoute().Methods(http.MethodDelete).Handler(middleware.WithMiddlewares(manifest.NewHandler(project.Mgr, proxy), middleware.ReadOnly()))
manifestRouter.NewRoute().Methods(http.MethodPut).Handler(middleware.WithMiddlewares(manifest.NewHandler(project.Mgr, proxy), readonly.Middleware()))
manifestRouter.NewRoute().Methods(http.MethodDelete).Handler(middleware.WithMiddlewares(manifest.NewHandler(project.Mgr, proxy), readonly.Middleware()))
// handle blob
// as we need to apply middleware to the blob requests, so create a sub router to handle the blob APIs
blobRouter := rootRouter.PathPrefix("/v2/{name:.*}/blobs/").Subrouter()
blobRouter.NewRoute().Methods(http.MethodGet).Handler(blob.NewHandler(proxy))
blobRouter.NewRoute().Methods(http.MethodHead).Handler(blob.NewHandler(proxy))
blobRouter.NewRoute().Methods(http.MethodPost).Handler(middleware.WithMiddlewares(blob.NewHandler(proxy), middleware.ReadOnly()))
blobRouter.NewRoute().Methods(http.MethodPut).Handler(middleware.WithMiddlewares(blob.NewHandler(proxy), middleware.ReadOnly()))
blobRouter.NewRoute().Methods(http.MethodPatch).Handler(middleware.WithMiddlewares(blob.NewHandler(proxy), middleware.ReadOnly()))
blobRouter.NewRoute().Methods(http.MethodDelete).Handler(middleware.WithMiddlewares(blob.NewHandler(proxy), middleware.ReadOnly()))
blobRouter.NewRoute().Methods(http.MethodPost).Handler(middleware.WithMiddlewares(blob.NewHandler(proxy), readonly.Middleware()))
blobRouter.NewRoute().Methods(http.MethodPut).Handler(middleware.WithMiddlewares(blob.NewHandler(proxy), readonly.Middleware()))
blobRouter.NewRoute().Methods(http.MethodPatch).Handler(middleware.WithMiddlewares(blob.NewHandler(proxy), readonly.Middleware()))
blobRouter.NewRoute().Methods(http.MethodDelete).Handler(middleware.WithMiddlewares(blob.NewHandler(proxy), readonly.Middleware()))
// all other APIs are proxy to the backend docker registry
rootRouter.PathPrefix("/").Handler(proxy)

View File

@ -1,4 +0,0 @@
github.com/astaxie/beego/*/*:S1012
github.com/astaxie/beego/*:S1012
github.com/astaxie/beego/*/*:S1007
github.com/astaxie/beego/*:S1007

View File

@ -1,19 +1,26 @@
language: go
go:
- 1.6.4
- 1.7.5
- 1.8.1
- "1.11.x"
services:
- redis-server
- mysql
- postgresql
- memcached
env:
- ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
- ORM_DRIVER=mysql ORM_SOURCE="root:@/orm_test?charset=utf8"
- ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
global:
- GO_REPO_FULLNAME="github.com/astaxie/beego"
matrix:
- ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
- ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
before_install:
# link the local repo with ${GOPATH}/src/<namespace>/<repo>
- GO_REPO_NAMESPACE=${GO_REPO_FULLNAME%/*}
# relies on GOPATH to contain only one directory...
- mkdir -p ${GOPATH}/src/${GO_REPO_NAMESPACE}
- ln -sv ${TRAVIS_BUILD_DIR} ${GOPATH}/src/${GO_REPO_FULLNAME}
- cd ${GOPATH}/src/${GO_REPO_FULLNAME}
# get and build ssdb
- git clone git://github.com/ideawu/ssdb.git
- cd ssdb
- make
@ -23,10 +30,11 @@ install:
- go get github.com/go-sql-driver/mysql
- go get github.com/mattn/go-sqlite3
- go get github.com/bradfitz/gomemcache/memcache
- go get github.com/garyburd/redigo/redis
- go get github.com/gomodule/redigo/redis
- go get github.com/beego/x2j
- go get github.com/couchbase/go-couchbase
- go get github.com/beego/goyaml2
- go get gopkg.in/yaml.v2
- go get github.com/belogik/goes
- go get github.com/siddontang/ledisdb/config
- go get github.com/siddontang/ledisdb/ledis
@ -35,28 +43,32 @@ install:
- go get github.com/gogo/protobuf/proto
- go get github.com/Knetic/govaluate
- go get github.com/casbin/casbin
- go get -u honnef.co/go/tools/cmd/gosimple
- go get github.com/elazarl/go-bindata-assetfs
- go get github.com/OwnLocal/goes
- go get github.com/shiena/ansicolor
- go get -u honnef.co/go/tools/cmd/staticcheck
- go get -u github.com/mdempsky/unconvert
- go get -u github.com/gordonklaus/ineffassign
- go get -u github.com/golang/lint/golint
- go get -u github.com/go-redis/redis
before_script:
- psql --version
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
- sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi"
- sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi"
- sh -c "if [ $(go version) == *1.[5-9]* ]; then go get github.com/golang/lint/golint; golint ./...; fi"
- sh -c "if [ $(go version) == *1.[5-9]* ]; then go tool vet .; fi"
- sh -c "go get github.com/golang/lint/golint; golint ./...;"
- sh -c "go list ./... | grep -v vendor | xargs go vet -v"
- mkdir -p res/var
- ./ssdb/ssdb-server ./ssdb/ssdb.conf -d
after_script:
-killall -w ssdb-server
- killall -w ssdb-server
- rm -rf ./res/var/*
script:
- go test -v ./...
- gosimple -ignore "$(cat .gosimpleignore)" $(go list ./... | grep -v /vendor/)
- staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024"
- unconvert $(go list ./... | grep -v /vendor/)
- ineffassign .
- find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s
- golint ./...
addons:
postgresql: "9.4"
postgresql: "9.6"

View File

@ -1,8 +1,11 @@
# Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org)
# Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) [![Go Report Card](https://goreportcard.com/badge/github.com/astaxie/beego)](https://goreportcard.com/report/github.com/astaxie/beego)
beego is used for rapid development of RESTful APIs, web apps and backend services in Go.
It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding.
Response time ranking: [web-frameworks](https://github.com/the-benchmarker/web-frameworks).
###### More info at [beego.me](http://beego.me).
## Quick Start

View File

@ -20,11 +20,10 @@ import (
"fmt"
"net/http"
"os"
"reflect"
"text/template"
"time"
"reflect"
"github.com/astaxie/beego/grace"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox"
@ -35,7 +34,7 @@ import (
var beeAdminApp *adminApp
// FilterMonitorFunc is default monitor filter when admin module is enable.
// if this func returns, admin module records qbs for this request by condition of this function logic.
// if this func returns, admin module records qps for this request by condition of this function logic.
// usage:
// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool {
// if method == "POST" {
@ -67,15 +66,27 @@ func init() {
// AdminIndex is the default http.Handler for admin module.
// it matches url pattern "/".
func adminIndex(rw http.ResponseWriter, r *http.Request) {
func adminIndex(rw http.ResponseWriter, _ *http.Request) {
execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl)
}
// QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter.
// it's registered with url pattern "/qbs" in admin module.
func qpsIndex(rw http.ResponseWriter, r *http.Request) {
// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter.
// it's registered with url pattern "/qps" in admin module.
func qpsIndex(rw http.ResponseWriter, _ *http.Request) {
data := make(map[interface{}]interface{})
data["Content"] = toolbox.StatisticsMap.GetMap()
// do html escape before display path, avoid xss
if content, ok := (data["Content"]).(M); ok {
if resultLists, ok := (content["Data"]).([][]string); ok {
for i := range resultLists {
if len(resultLists[i]) > 0 {
resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0])
}
}
}
}
execTpl(rw, data, qpsTpl, defaultScriptsTpl)
}
@ -92,7 +103,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
data := make(map[interface{}]interface{})
switch command {
case "conf":
m := make(map[string]interface{})
m := make(M)
list("BConfig", BConfig, m)
m["AppConfigPath"] = appConfigPath
m["AppConfigProvider"] = appConfigProvider
@ -116,14 +127,14 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
case "filter":
var (
content = map[string]interface{}{
content = M{
"Fields": []string{
"Router Pattern",
"Filter Function",
},
}
filterTypes = []string{}
filterTypeData = make(map[string]interface{})
filterTypeData = make(M)
)
if BeeApp.Handlers.enableFilter {
@ -161,7 +172,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
}
}
func list(root string, p interface{}, m map[string]interface{}) {
func list(root string, p interface{}, m M) {
pt := reflect.TypeOf(p)
pv := reflect.ValueOf(p)
if pt.Kind() == reflect.Ptr {
@ -184,11 +195,11 @@ func list(root string, p interface{}, m map[string]interface{}) {
}
// PrintTree prints all registered routers.
func PrintTree() map[string]interface{} {
func PrintTree() M {
var (
content = map[string]interface{}{}
content = M{}
methods = []string{}
methodsData = make(map[string]interface{})
methodsData = make(M)
)
for method, t := range BeeApp.Handlers.routers {
@ -279,12 +290,12 @@ func profIndex(rw http.ResponseWriter, r *http.Request) {
// Healthcheck is a http.Handler calling health checking and showing the result.
// it's in "/healthcheck" pattern in admin module.
func healthcheck(rw http.ResponseWriter, req *http.Request) {
func healthcheck(rw http.ResponseWriter, _ *http.Request) {
var (
result []string
data = make(map[interface{}]interface{})
resultList = new([][]string)
content = map[string]interface{}{
content = M{
"Fields": []string{"Name", "Message", "Status"},
}
)
@ -332,7 +343,7 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
}
// List Tasks
content := make(map[string]interface{})
content := make(M)
resultList := new([][]string)
var fields = []string{
"Task Name",

View File

@ -15,17 +15,22 @@
package beego
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/fcgi"
"os"
"path"
"strings"
"time"
"github.com/astaxie/beego/grace"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils"
"golang.org/x/crypto/acme/autocert"
)
var (
@ -51,8 +56,11 @@ func NewApp() *App {
return app
}
// MiddleWare function for http.Handler
type MiddleWare func(http.Handler) http.Handler
// Run beego application.
func (app *App) Run() {
func (app *App) Run(mws ...MiddleWare) {
addr := BConfig.Listen.HTTPAddr
if BConfig.Listen.HTTPPort != 0 {
@ -94,6 +102,12 @@ func (app *App) Run() {
}
app.Server.Handler = app.Handlers
for i := len(mws) - 1; i >= 0; i-- {
if mws[i] == nil {
continue
}
app.Server.Handler = mws[i](app.Server.Handler)
}
app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.ErrorLog = logs.GetLogger("HTTP")
@ -102,9 +116,9 @@ func (app *App) Run() {
if BConfig.Listen.Graceful {
httpsAddr := BConfig.Listen.HTTPSAddr
app.Server.Addr = httpsAddr
if BConfig.Listen.EnableHTTPS {
if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS {
go func() {
time.Sleep(20 * time.Microsecond)
time.Sleep(1000 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 {
httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
app.Server.Addr = httpsAddr
@ -112,10 +126,27 @@ func (app *App) Run() {
server := grace.NewServer(httpsAddr, app.Handlers)
server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
if BConfig.Listen.EnableMutualHTTPS {
if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
} else {
if BConfig.Listen.AutoTLS {
m := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...),
Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir),
}
app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate}
BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", ""
}
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}
}()
}
@ -139,22 +170,44 @@ func (app *App) Run() {
}
// run normal mode
if BConfig.Listen.EnableHTTPS {
if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS {
go func() {
time.Sleep(20 * time.Microsecond)
time.Sleep(1000 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
} else if BConfig.Listen.EnableHTTP {
BeeLogger.Info("Start https server error, confict with http.Please reset https port")
logs.Info("Start https server error, conflict with http. Please reset https port")
return
}
logs.Info("https server Running on https://%s", app.Server.Addr)
if BConfig.Listen.AutoTLS {
m := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...),
Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir),
}
app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate}
BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", ""
} else if BConfig.Listen.EnableMutualHTTPS {
pool := x509.NewCertPool()
data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile)
if err != nil {
logs.Info("MutualHTTPS should provide TrustCaFile")
return
}
pool.AppendCertsFromPEM(data)
app.Server.TLSConfig = &tls.Config{
ClientCAs: pool,
ClientAuth: tls.RequireAndVerifyClientCert,
}
}
if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
if BConfig.Listen.EnableHTTP {
go func() {
@ -207,6 +260,84 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *A
return BeeApp
}
// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful
// in web applications that inherit most routes from a base webapp via the underscore
// import, and aim to overwrite only certain paths.
// The method parameter can be empty or "*" for all HTTP methods, or a particular
// method type (e.g. "GET" or "POST") for selective removal.
//
// Usage (replace "GET" with "*" for all methods):
// beego.UnregisterFixedRoute("/yourpreviouspath", "GET")
// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage")
func UnregisterFixedRoute(fixedRoute string, method string) *App {
subPaths := splitPath(fixedRoute)
if method == "" || method == "*" {
for m := range HTTPMETHOD {
if _, ok := BeeApp.Handlers.routers[m]; !ok {
continue
}
if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") {
findAndRemoveSingleTree(BeeApp.Handlers.routers[m])
continue
}
findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m)
}
return BeeApp
}
// Single HTTP method
um := strings.ToUpper(method)
if _, ok := BeeApp.Handlers.routers[um]; ok {
if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") {
findAndRemoveSingleTree(BeeApp.Handlers.routers[um])
return BeeApp
}
findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um)
}
return BeeApp
}
func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) {
for i := range entryPointTree.fixrouters {
if entryPointTree.fixrouters[i].prefix == paths[0] {
if len(paths) == 1 {
if len(entryPointTree.fixrouters[i].fixrouters) > 0 {
// If the route had children subtrees, remove just the functional leaf,
// to allow children to function as before
if len(entryPointTree.fixrouters[i].leaves) > 0 {
entryPointTree.fixrouters[i].leaves[0] = nil
entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:]
}
} else {
// Remove the *Tree from the fixrouters slice
entryPointTree.fixrouters[i] = nil
if i == len(entryPointTree.fixrouters)-1 {
entryPointTree.fixrouters = entryPointTree.fixrouters[:i]
} else {
entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...)
}
}
return
}
findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method)
}
}
}
func findAndRemoveSingleTree(entryPointTree *Tree) {
if entryPointTree == nil {
return
}
if len(entryPointTree.fixrouters) > 0 {
// If the route had children subtrees, remove just the functional leaf,
// to allow children to function as before
if len(entryPointTree.leaves) > 0 {
entryPointTree.leaves[0] = nil
entryPointTree.leaves = entryPointTree.leaves[1:]
}
}
}
// Include will generate router file in the router/xxx.go from the controller's comments
// usage:
// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})

View File

@ -23,7 +23,7 @@ import (
const (
// VERSION represent beego web framework version.
VERSION = "1.9.0"
VERSION = "1.12.0"
// DEV is for develop
DEV = "dev"
@ -31,7 +31,10 @@ const (
PROD = "prod"
)
//hook function to run
// M is Map shortcut
type M map[string]interface{}
// Hook function to run
type hookfunc func() error
var (
@ -62,11 +65,29 @@ func Run(params ...string) {
if len(strs) > 1 && strs[1] != "" {
BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1])
}
BConfig.Listen.Domains = params
}
BeeApp.Run()
}
// RunWithMiddleWares Run beego application with middlewares.
func RunWithMiddleWares(addr string, mws ...MiddleWare) {
initBeforeHTTPRun()
strs := strings.Split(addr, ":")
if len(strs) > 0 && strs[0] != "" {
BConfig.Listen.HTTPAddr = strs[0]
BConfig.Listen.Domains = []string{strs[0]}
}
if len(strs) > 1 && strs[1] != "" {
BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1])
}
BeeApp.Run(mws...)
}
func initBeforeHTTPRun() {
//init hooks
AddAPPStartHook(

View File

@ -52,7 +52,7 @@ Configure like this:
## Redis adapter
Redis adapter use the [redigo](http://github.com/garyburd/redigo) client.
Redis adapter use the [redigo](http://github.com/gomodule/redigo) client.
Configure like this:

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Package cache provide a Cache interface and some implemetn engine
// Package cache provide a Cache interface and some implement engine
// Usage:
//
// import(

View File

@ -62,11 +62,14 @@ func NewFileCache() Cache {
}
// StartAndGC will start and begin gc for file cache.
// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}
// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}
func (fc *FileCache) StartAndGC(config string) error {
var cfg map[string]string
json.Unmarshal([]byte(config), &cfg)
cfg := make(map[string]string)
err := json.Unmarshal([]byte(config), &cfg)
if err != nil {
return err
}
if _, ok := cfg["CachePath"]; !ok {
cfg["CachePath"] = FileCachePath
}
@ -142,12 +145,12 @@ func (fc *FileCache) GetMulti(keys []string) []interface{} {
// Put value into file cache.
// timeout means how long to keep this file, unit of ms.
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever.
func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error {
gob.Register(val)
item := FileCacheItem{Data: val}
if timeout == FileCacheEmbedExpiry {
if timeout == time.Duration(fc.EmbedExpiry) {
item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years
} else {
item.Expired = time.Now().Add(timeout)
@ -179,7 +182,7 @@ func (fc *FileCache) Incr(key string) error {
} else {
incr = data.(int) + 1
}
fc.Put(key, incr, FileCacheEmbedExpiry)
fc.Put(key, incr, time.Duration(fc.EmbedExpiry))
return nil
}
@ -192,7 +195,7 @@ func (fc *FileCache) Decr(key string) error {
} else {
decr = data.(int) - 1
}
fc.Put(key, decr, FileCacheEmbedExpiry)
fc.Put(key, decr, time.Duration(fc.EmbedExpiry))
return nil
}

View File

@ -110,25 +110,25 @@ func (bc *MemoryCache) Delete(name string) error {
// Incr increase cache counter in memory.
// it supports int,int32,int64,uint,uint32,uint64.
func (bc *MemoryCache) Incr(key string) error {
bc.RLock()
defer bc.RUnlock()
bc.Lock()
defer bc.Unlock()
itm, ok := bc.items[key]
if !ok {
return errors.New("key not exist")
}
switch itm.val.(type) {
switch val := itm.val.(type) {
case int:
itm.val = itm.val.(int) + 1
itm.val = val + 1
case int32:
itm.val = itm.val.(int32) + 1
itm.val = val + 1
case int64:
itm.val = itm.val.(int64) + 1
itm.val = val + 1
case uint:
itm.val = itm.val.(uint) + 1
itm.val = val + 1
case uint32:
itm.val = itm.val.(uint32) + 1
itm.val = val + 1
case uint64:
itm.val = itm.val.(uint64) + 1
itm.val = val + 1
default:
return errors.New("item val is not (u)int (u)int32 (u)int64")
}
@ -137,34 +137,34 @@ func (bc *MemoryCache) Incr(key string) error {
// Decr decrease counter in memory.
func (bc *MemoryCache) Decr(key string) error {
bc.RLock()
defer bc.RUnlock()
bc.Lock()
defer bc.Unlock()
itm, ok := bc.items[key]
if !ok {
return errors.New("key not exist")
}
switch itm.val.(type) {
switch val := itm.val.(type) {
case int:
itm.val = itm.val.(int) - 1
itm.val = val - 1
case int64:
itm.val = itm.val.(int64) - 1
itm.val = val - 1
case int32:
itm.val = itm.val.(int32) - 1
itm.val = val - 1
case uint:
if itm.val.(uint) > 0 {
itm.val = itm.val.(uint) - 1
if val > 0 {
itm.val = val - 1
} else {
return errors.New("item val is less than 0")
}
case uint32:
if itm.val.(uint32) > 0 {
itm.val = itm.val.(uint32) - 1
if val > 0 {
itm.val = val - 1
} else {
return errors.New("item val is less than 0")
}
case uint64:
if itm.val.(uint64) > 0 {
itm.val = itm.val.(uint64) - 1
if val > 0 {
itm.val = val - 1
} else {
return errors.New("item val is less than 0")
}
@ -203,13 +203,17 @@ func (bc *MemoryCache) StartAndGC(config string) error {
dur := time.Duration(cf["interval"]) * time.Second
bc.Every = cf["interval"]
bc.dur = dur
go bc.vaccuum()
go bc.vacuum()
return nil
}
// check expiration.
func (bc *MemoryCache) vaccuum() {
if bc.Every < 1 {
func (bc *MemoryCache) vacuum() {
bc.RLock()
every := bc.Every
bc.RUnlock()
if every < 1 {
return
}
for {

View File

@ -14,9 +14,9 @@
// Package redis for cache provider
//
// depend on github.com/garyburd/redigo/redis
// depend on github.com/gomodule/redigo/redis
//
// go install github.com/garyburd/redigo/redis
// go install github.com/gomodule/redigo/redis
//
// Usage:
// import(
@ -32,12 +32,14 @@ package redis
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/garyburd/redigo/redis"
"github.com/gomodule/redigo/redis"
"github.com/astaxie/beego/cache"
"strings"
)
var (
@ -52,6 +54,7 @@ type Cache struct {
dbNum int
key string
password string
maxIdle int
}
// NewRedisCache create new redis cache with default collection name.
@ -59,14 +62,23 @@ func NewRedisCache() cache.Cache {
return &Cache{key: DefaultKey}
}
// actually do the redis cmds
// actually do the redis cmds, args[0] must be the key name.
func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
if len(args) < 1 {
return nil, errors.New("missing required arguments")
}
args[0] = rc.associate(args[0])
c := rc.p.Get()
defer c.Close()
return c.Do(commandName, args...)
}
// associate with config key.
func (rc *Cache) associate(originKey interface{}) string {
return fmt.Sprintf("%s:%s", rc.key, originKey)
}
// Get cache from redis.
func (rc *Cache) Get(key string) interface{} {
if v, err := rc.do("GET", key); err == nil {
@ -77,57 +89,28 @@ func (rc *Cache) Get(key string) interface{} {
// GetMulti get cache from redis.
func (rc *Cache) GetMulti(keys []string) []interface{} {
size := len(keys)
var rv []interface{}
c := rc.p.Get()
defer c.Close()
var err error
var args []interface{}
for _, key := range keys {
err = c.Send("GET", key)
if err != nil {
goto ERROR
}
args = append(args, rc.associate(key))
}
if err = c.Flush(); err != nil {
goto ERROR
values, err := redis.Values(c.Do("MGET", args...))
if err != nil {
return nil
}
for i := 0; i < size; i++ {
if v, err := c.Receive(); err == nil {
rv = append(rv, v.([]byte))
} else {
rv = append(rv, err)
}
}
return rv
ERROR:
rv = rv[0:0]
for i := 0; i < size; i++ {
rv = append(rv, nil)
}
return rv
return values
}
// Put put cache to redis.
func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
var err error
if _, err = rc.do("SETEX", key, int64(timeout/time.Second), val); err != nil {
return err
}
if _, err = rc.do("HSET", rc.key, key, true); err != nil {
return err
}
_, err := rc.do("SETEX", key, int64(timeout/time.Second), val)
return err
}
// Delete delete cache in redis.
func (rc *Cache) Delete(key string) error {
var err error
if _, err = rc.do("DEL", key); err != nil {
return err
}
_, err = rc.do("HDEL", rc.key, key)
_, err := rc.do("DEL", key)
return err
}
@ -137,11 +120,6 @@ func (rc *Cache) IsExist(key string) bool {
if err != nil {
return false
}
if !v {
if _, err = rc.do("HDEL", rc.key, key); err != nil {
return false
}
}
return v
}
@ -159,16 +137,17 @@ func (rc *Cache) Decr(key string) error {
// ClearAll clean all cache in redis. delete this redis collection.
func (rc *Cache) ClearAll() error {
cachedKeys, err := redis.Strings(rc.do("HKEYS", rc.key))
c := rc.p.Get()
defer c.Close()
cachedKeys, err := redis.Strings(c.Do("KEYS", rc.key+":*"))
if err != nil {
return err
}
for _, str := range cachedKeys {
if _, err = rc.do("DEL", str); err != nil {
if _, err = c.Do("DEL", str); err != nil {
return err
}
}
_, err = rc.do("DEL", rc.key)
return err
}
@ -186,16 +165,28 @@ func (rc *Cache) StartAndGC(config string) error {
if _, ok := cf["conn"]; !ok {
return errors.New("config has no conn key")
}
// Format redis://<password>@<host>:<port>
cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1)
if i := strings.Index(cf["conn"], "@"); i > -1 {
cf["password"] = cf["conn"][0:i]
cf["conn"] = cf["conn"][i+1:]
}
if _, ok := cf["dbNum"]; !ok {
cf["dbNum"] = "0"
}
if _, ok := cf["password"]; !ok {
cf["password"] = ""
}
if _, ok := cf["maxIdle"]; !ok {
cf["maxIdle"] = "3"
}
rc.key = cf["key"]
rc.conninfo = cf["conn"]
rc.dbNum, _ = strconv.Atoi(cf["dbNum"])
rc.password = cf["password"]
rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"])
rc.connectInit()
@ -229,7 +220,7 @@ func (rc *Cache) connectInit() {
}
// initialize a new pool
rc.p = &redis.Pool{
MaxIdle: 3,
MaxIdle: rc.maxIdle,
IdleTimeout: 180 * time.Second,
Dial: dialFunc,
}

View File

@ -49,22 +49,27 @@ type Config struct {
// Listen holds for http and https related config
type Listen struct {
Graceful bool // Graceful means use graceful module to start the server
ServerTimeOut int64
ListenTCP4 bool
EnableHTTP bool
HTTPAddr string
HTTPPort int
EnableHTTPS bool
HTTPSAddr string
HTTPSPort int
HTTPSCertFile string
HTTPSKeyFile string
EnableAdmin bool
AdminAddr string
AdminPort int
EnableFcgi bool
EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O
Graceful bool // Graceful means use graceful module to start the server
ServerTimeOut int64
ListenTCP4 bool
EnableHTTP bool
HTTPAddr string
HTTPPort int
AutoTLS bool
Domains []string
TLSCacheDir string
EnableHTTPS bool
EnableMutualHTTPS bool
HTTPSAddr string
HTTPSPort int
HTTPSCertFile string
HTTPSKeyFile string
TrustCaFile string
EnableAdmin bool
AdminAddr string
AdminPort int
EnableFcgi bool
EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O
}
// WebConfig holds web related config
@ -96,16 +101,18 @@ type SessionConfig struct {
SessionAutoSetCookie bool
SessionDomain string
SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies.
SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers
SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers
SessionNameInHTTPHeader string
SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params
SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params
}
// LogConfig holds Log related config
type LogConfig struct {
AccessLogs bool
FileLineNum bool
Outputs map[string]string // Store Adaptor : config
AccessLogs bool
EnableStaticLogs bool //log static files requests default: false
AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string
FileLineNum bool
Outputs map[string]string // Store Adaptor : config
}
var (
@ -134,9 +141,13 @@ func init() {
if err != nil {
panic(err)
}
appConfigPath = filepath.Join(workPath, "conf", "app.conf")
var filename = "app.conf"
if os.Getenv("BEEGO_RUNMODE") != "" {
filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf"
}
appConfigPath = filepath.Join(workPath, "conf", filename)
if !utils.FileExists(appConfigPath) {
appConfigPath = filepath.Join(AppPath, "conf", "app.conf")
appConfigPath = filepath.Join(AppPath, "conf", filename)
if !utils.FileExists(appConfigPath) {
AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()}
return
@ -175,13 +186,18 @@ func recoverPanic(ctx *context.Context) {
if BConfig.RunMode == DEV && BConfig.EnableErrorsRender {
showErr(err, ctx, stack)
}
if ctx.Output.Status != 0 {
ctx.ResponseWriter.WriteHeader(ctx.Output.Status)
} else {
ctx.ResponseWriter.WriteHeader(500)
}
}
}
func newBConfig() *Config {
return &Config{
AppName: "beego",
RunMode: DEV,
RunMode: PROD,
RouterCaseSensitive: true,
ServerName: "beegoServer:" + VERSION,
RecoverPanic: true,
@ -196,6 +212,9 @@ func newBConfig() *Config {
ServerTimeOut: 0,
ListenTCP4: false,
EnableHTTP: true,
AutoTLS: false,
Domains: []string{},
TLSCacheDir: ".",
HTTPAddr: "",
HTTPPort: 8080,
EnableHTTPS: false,
@ -233,15 +252,17 @@ func newBConfig() *Config {
SessionCookieLifeTime: 0, //set cookie default is the browser life
SessionAutoSetCookie: true,
SessionDomain: "",
SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers
SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers
SessionNameInHTTPHeader: "Beegosessionid",
SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params
SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params
},
},
Log: LogConfig{
AccessLogs: false,
FileLineNum: true,
Outputs: map[string]string{"console": ""},
AccessLogs: false,
EnableStaticLogs: false,
AccessLogsFormat: "APACHE_FORMAT",
FileLineNum: true,
Outputs: map[string]string{"console": ""},
},
}
}

View File

@ -150,12 +150,12 @@ func ExpandValueEnv(value string) (realValue string) {
}
key := ""
defalutV := ""
defaultV := ""
// value start with "${"
for i := 2; i < vLen; i++ {
if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') {
key = value[2:i]
defalutV = value[i+2 : vLen-1] // other string is default value.
defaultV = value[i+2 : vLen-1] // other string is default value.
break
} else if value[i] == '}' {
key = value[2:i]
@ -165,7 +165,7 @@ func ExpandValueEnv(value string) (realValue string) {
realValue = os.Getenv(key)
if realValue == "" {
realValue = defalutV
realValue = defaultV
}
return

View File

@ -126,7 +126,7 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error {
var _ Configer = new(fakeConfigContainer)
// NewFakeConfig return a fake Congiger
// NewFakeConfig return a fake Configer
func NewFakeConfig() Configer {
return &fakeConfigContainer{
data: make(map[string]string),

View File

@ -78,15 +78,37 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e
}
}
section := defaultSection
tmpBuf := bytes.NewBuffer(nil)
for {
line, _, err := buf.ReadLine()
if err == io.EOF {
tmpBuf.Reset()
shouldBreak := false
for {
tmp, isPrefix, err := buf.ReadLine()
if err == io.EOF {
shouldBreak = true
break
}
//It might be a good idea to throw a error on all unknonw errors?
if _, ok := err.(*os.PathError); ok {
return nil, err
}
tmpBuf.Write(tmp)
if isPrefix {
continue
}
if !isPrefix {
break
}
}
if shouldBreak {
break
}
//It might be a good idea to throw a error on all unknonw errors?
if _, ok := err.(*os.PathError); ok {
return nil, err
}
line := tmpBuf.Bytes()
line = bytes.TrimSpace(line)
if bytes.Equal(line, bEmpty) {
continue
@ -215,7 +237,7 @@ func (c *IniConfigContainer) Bool(key string) (bool, error) {
}
// DefaultBool returns the boolean value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool {
v, err := c.Bool(key)
if err != nil {
@ -230,7 +252,7 @@ func (c *IniConfigContainer) Int(key string) (int, error) {
}
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int {
v, err := c.Int(key)
if err != nil {
@ -245,7 +267,7 @@ func (c *IniConfigContainer) Int64(key string) (int64, error) {
}
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
v, err := c.Int64(key)
if err != nil {
@ -260,7 +282,7 @@ func (c *IniConfigContainer) Float(key string) (float64, error) {
}
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
v, err := c.Float(key)
if err != nil {
@ -275,7 +297,7 @@ func (c *IniConfigContainer) String(key string) string {
}
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *IniConfigContainer) DefaultString(key string, defaultval string) string {
v := c.String(key)
if v == "" {
@ -295,7 +317,7 @@ func (c *IniConfigContainer) Strings(key string) []string {
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string {
v := c.Strings(key)
if v == nil {
@ -314,7 +336,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro
// SaveConfigFile save the config into file.
//
// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Funcation.
// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function.
func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)

View File

@ -101,7 +101,7 @@ func (c *JSONConfigContainer) Int(key string) (int, error) {
}
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err == nil {
return v
@ -122,7 +122,7 @@ func (c *JSONConfigContainer) Int64(key string) (int64, error) {
}
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err == nil {
return v
@ -143,7 +143,7 @@ func (c *JSONConfigContainer) Float(key string) (float64, error) {
}
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err == nil {
return v
@ -163,7 +163,7 @@ func (c *JSONConfigContainer) String(key string) string {
}
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string {
// TODO FIXME should not use "" to replace non existence
if v := c.String(key); v != "" {
@ -182,7 +182,7 @@ func (c *JSONConfigContainer) Strings(key string) []string {
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
// if err != nil return defaultval
func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); v != nil {
return v

View File

@ -38,6 +38,14 @@ import (
"github.com/astaxie/beego/utils"
)
//commonly used mime-types
const (
ApplicationJSON = "application/json"
ApplicationXML = "application/xml"
ApplicationYAML = "application/x-yaml"
TextXML = "text/xml"
)
// NewContext return the Context with Input and Output
func NewContext() *Context {
return &Context{
@ -193,6 +201,7 @@ type Response struct {
http.ResponseWriter
Started bool
Status int
Elapsed time.Duration
}
func (r *Response) reset(rw http.ResponseWriter) {
@ -244,3 +253,11 @@ func (r *Response) CloseNotify() <-chan bool {
}
return nil
}
// Pusher http.Pusher
func (r *Response) Pusher() (pusher http.Pusher) {
if pusher, ok := r.ResponseWriter.(http.Pusher); ok {
return pusher
}
return nil
}

View File

@ -20,12 +20,14 @@ import (
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
"github.com/astaxie/beego/session"
)
@ -36,6 +38,7 @@ var (
acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`)
acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`)
acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`)
acceptsYAMLRegex = regexp.MustCompile(`(application/x-yaml)(?:,|$)`)
maxParam = 50
)
@ -47,6 +50,7 @@ type BeegoInput struct {
pnames []string
pvalues []string
data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
dataLock sync.RWMutex
RequestBody []byte
RunMethod string
RunController reflect.Type
@ -115,9 +119,8 @@ func (input *BeegoInput) Domain() string {
// if no host info in request, return localhost.
func (input *BeegoInput) Host() string {
if input.Context.Request.Host != "" {
hostParts := strings.Split(input.Context.Request.Host, ":")
if len(hostParts) > 0 {
return hostParts[0]
if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil {
return hostPart
}
return input.Context.Request.Host
}
@ -204,22 +207,27 @@ func (input *BeegoInput) AcceptsJSON() bool {
return acceptsJSONRegex.MatchString(input.Header("Accept"))
}
// AcceptsYAML Checks if request accepts json response
func (input *BeegoInput) AcceptsYAML() bool {
return acceptsYAMLRegex.MatchString(input.Header("Accept"))
}
// IP returns request client ip.
// if in proxy, return first proxy id.
// if error, return 127.0.0.1.
// if error, return RemoteAddr.
func (input *BeegoInput) IP() string {
ips := input.Proxy()
if len(ips) > 0 && ips[0] != "" {
rip := strings.Split(ips[0], ":")
return rip[0]
}
ip := strings.Split(input.Context.Request.RemoteAddr, ":")
if len(ip) > 0 {
if ip[0] != "[" {
return ip[0]
rip, _, err := net.SplitHostPort(ips[0])
if err != nil {
rip = ips[0]
}
return rip
}
return "127.0.0.1"
if ip, _, err := net.SplitHostPort(input.Context.Request.RemoteAddr); err == nil {
return ip
}
return input.Context.Request.RemoteAddr
}
// Proxy returns proxy client ips slice.
@ -253,9 +261,8 @@ func (input *BeegoInput) SubDomains() string {
// Port returns request client port.
// when error or empty, return 80.
func (input *BeegoInput) Port() int {
parts := strings.Split(input.Context.Request.Host, ":")
if len(parts) == 2 {
port, _ := strconv.Atoi(parts[1])
if _, portPart, err := net.SplitHostPort(input.Context.Request.Host); err == nil {
port, _ := strconv.Atoi(portPart)
return port
}
return 80
@ -373,6 +380,8 @@ func (input *BeegoInput) CopyBody(MaxMemory int64) []byte {
// Data return the implicit data in the input
func (input *BeegoInput) Data() map[interface{}]interface{} {
input.dataLock.Lock()
defer input.dataLock.Unlock()
if input.data == nil {
input.data = make(map[interface{}]interface{})
}
@ -381,6 +390,8 @@ func (input *BeegoInput) Data() map[interface{}]interface{} {
// GetData returns the stored data in this context.
func (input *BeegoInput) GetData(key interface{}) interface{} {
input.dataLock.Lock()
defer input.dataLock.Unlock()
if v, ok := input.data[key]; ok {
return v
}
@ -390,6 +401,8 @@ func (input *BeegoInput) GetData(key interface{}) interface{} {
// SetData stores data with given key in this context.
// This data are only available in this context.
func (input *BeegoInput) SetData(key, val interface{}) {
input.dataLock.Lock()
defer input.dataLock.Unlock()
if input.data == nil {
input.data = make(map[interface{}]interface{})
}

View File

@ -30,6 +30,8 @@ import (
"strconv"
"strings"
"time"
yaml "gopkg.in/yaml.v2"
)
// BeegoOutput does work for sending response header.
@ -182,8 +184,8 @@ func errorRenderer(err error) Renderer {
}
// JSON writes json to response body.
// if coding is true, it converts utf-8 to \u0000 type.
func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error {
// if encoding is true, it converts utf-8 to \u0000 type.
func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error {
output.Header("Content-Type", "application/json; charset=utf-8")
var content []byte
var err error
@ -196,12 +198,25 @@ func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) e
http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError)
return err
}
if coding {
if encoding {
content = []byte(stringsToJSON(string(content)))
}
return output.Body(content)
}
// YAML writes yaml to response body.
func (output *BeegoOutput) YAML(data interface{}) error {
output.Header("Content-Type", "application/x-yaml; charset=utf-8")
var content []byte
var err error
content, err = yaml.Marshal(data)
if err != nil {
http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError)
return err
}
return output.Body(content)
}
// JSONP writes jsonp to response body.
func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/javascript; charset=utf-8")
@ -245,6 +260,19 @@ func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error {
return output.Body(content)
}
// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header
func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) {
accept := output.Context.Input.Header("Accept")
switch accept {
case ApplicationYAML:
output.YAML(data)
case ApplicationXML, TextXML:
output.XML(data, hasIndent)
default:
output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0])
}
}
// Download forces response for download file.
// it prepares the download response header automatically.
func (output *BeegoOutput) Download(file string, filename ...string) {
@ -260,7 +288,20 @@ func (output *BeegoOutput) Download(file string, filename ...string) {
} else {
fName = filepath.Base(file)
}
output.Header("Content-Disposition", "attachment; filename="+url.QueryEscape(fName))
//https://tools.ietf.org/html/rfc6266#section-4.3
fn := url.PathEscape(fName)
if fName == fn {
fn = "filename=" + fn
} else {
/**
The parameters "filename" and "filename*" differ only in that
"filename*" uses the encoding defined in [RFC5987], allowing the use
of characters not present in the ISO-8859-1 character set
([ISO-8859-1]).
*/
fn = "filename=" + fName + "; filename*=utf-8''" + fn
}
output.Header("Content-Disposition", "attachment; "+fn)
output.Header("Content-Description", "File Transfer")
output.Header("Content-Type", "application/octet-stream")
output.Header("Content-Transfer-Encoding", "binary")
@ -325,13 +366,13 @@ func (output *BeegoOutput) IsForbidden() bool {
}
// IsNotFound returns boolean of this request is not found.
// HTTP 404 means forbidden.
// HTTP 404 means not found.
func (output *BeegoOutput) IsNotFound() bool {
return output.Status == 404
}
// IsClientError returns boolean of this request client sends error data.
// HTTP 4xx means forbidden.
// HTTP 4xx means client error.
func (output *BeegoOutput) IsClientError() bool {
return output.Status >= 400 && output.Status < 500
}
@ -350,6 +391,11 @@ func stringsToJSON(str string) string {
jsons.WriteRune(r)
} else {
jsons.WriteString("\\u")
if rint < 0x100 {
jsons.WriteString("00")
} else if rint < 0x1000 {
jsons.WriteString("0")
}
jsons.WriteString(strconv.FormatInt(int64(rint), 16))
}
}

View File

@ -7,7 +7,7 @@ import (
// MethodParamOption defines a func which apply options on a MethodParam
type MethodParamOption func(*MethodParam)
// IsRequired indicates that this param is required and can not be ommited from the http request
// IsRequired indicates that this param is required and can not be omitted from the http request
var IsRequired MethodParamOption = func(p *MethodParam) {
p.required = true
}

View File

@ -17,6 +17,7 @@ package beego
import (
"bytes"
"errors"
"fmt"
"html/template"
"io"
"mime/multipart"
@ -32,24 +33,44 @@ import (
"github.com/astaxie/beego/session"
)
//commonly used mime-types
const (
applicationJSON = "application/json"
applicationXML = "application/xml"
textXML = "text/xml"
)
var (
// ErrAbort custom error when user stop request handler manually.
ErrAbort = errors.New("User stop run")
ErrAbort = errors.New("user stop run")
// GlobalControllerRouter store comments with controller. pkgpath+controller:comments
GlobalControllerRouter = make(map[string][]ControllerComments)
)
// ControllerFilter store the filter for controller
type ControllerFilter struct {
Pattern string
Pos int
Filter FilterFunc
ReturnOnOutput bool
ResetParams bool
}
// ControllerFilterComments store the comment for controller level filter
type ControllerFilterComments struct {
Pattern string
Pos int
Filter string // NOQA
ReturnOnOutput bool
ResetParams bool
}
// ControllerImportComments store the import comment for controller needed
type ControllerImportComments struct {
ImportPath string
ImportAlias string
}
// ControllerComments store the comment for the controller method
type ControllerComments struct {
Method string
Router string
Filters []*ControllerFilter
ImportComments []*ControllerImportComments
FilterComments []*ControllerFilterComments
AllowHTTPMethods []string
Params []map[string]string
MethodParams []*param.MethodParam
@ -73,7 +94,6 @@ type Controller struct {
controllerName string
actionName string
methodMapping map[string]func() //method:routertree
gotofunc string
AppController interface{}
// template data
@ -105,6 +125,7 @@ type ControllerInterface interface {
Head()
Patch()
Options()
Trace()
Finish()
Render() error
XSRFToken() string
@ -136,37 +157,59 @@ func (c *Controller) Finish() {}
// Get adds a request function to handle GET request.
func (c *Controller) Get() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Post adds a request function to handle POST request.
func (c *Controller) Post() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Delete adds a request function to handle DELETE request.
func (c *Controller) Delete() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Put adds a request function to handle PUT request.
func (c *Controller) Put() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Head adds a request function to handle HEAD request.
func (c *Controller) Head() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Patch adds a request function to handle PATCH request.
func (c *Controller) Patch() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Options adds a request function to handle OPTIONS request.
func (c *Controller) Options() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Trace adds a request function to handle Trace request.
// this method SHOULD NOT be overridden.
// https://tools.ietf.org/html/rfc7231#section-4.3.8
// The TRACE method requests a remote, application-level loop-back of
// the request message. The final recipient of the request SHOULD
// reflect the message received, excluding some fields described below,
// back to the client as the message body of a 200 (OK) response with a
// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]).
func (c *Controller) Trace() {
ts := func(h http.Header) (hs string) {
for k, v := range h {
hs += fmt.Sprintf("\r\n%s: %s", k, v)
}
return
}
hs := fmt.Sprintf("\r\nTRACE %s %s%s\r\n", c.Ctx.Request.RequestURI, c.Ctx.Request.Proto, ts(c.Ctx.Request.Header))
c.Ctx.Output.Header("Content-Type", "message/http")
c.Ctx.Output.Header("Content-Length", fmt.Sprint(len(hs)))
c.Ctx.Output.Header("Cache-Control", "no-cache, no-store, must-revalidate")
c.Ctx.WriteString(hs)
}
// HandlerFunc call function with the name
@ -272,9 +315,23 @@ func (c *Controller) viewPath() string {
// Redirect sends the redirection response to url with status code.
func (c *Controller) Redirect(url string, code int) {
LogAccess(c.Ctx, nil, code)
c.Ctx.Redirect(code, url)
}
// SetData set the data depending on the accepted
func (c *Controller) SetData(data interface{}) {
accept := c.Ctx.Input.Header("Accept")
switch accept {
case context.ApplicationYAML:
c.Data["yaml"] = data
case context.ApplicationXML, context.TextXML:
c.Data["xml"] = data
default:
c.Data["json"] = data
}
}
// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string.
func (c *Controller) Abort(code string) {
status, err := strconv.Atoi(code)
@ -317,47 +374,35 @@ func (c *Controller) URLFor(endpoint string, values ...interface{}) string {
// ServeJSON sends a json response with encoding charset.
func (c *Controller) ServeJSON(encoding ...bool) {
var (
hasIndent = true
hasEncoding = false
hasIndent = BConfig.RunMode != PROD
hasEncoding = len(encoding) > 0 && encoding[0]
)
if BConfig.RunMode == PROD {
hasIndent = false
}
if len(encoding) > 0 && encoding[0] {
hasEncoding = true
}
c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding)
}
// ServeJSONP sends a jsonp response.
func (c *Controller) ServeJSONP() {
hasIndent := true
if BConfig.RunMode == PROD {
hasIndent = false
}
hasIndent := BConfig.RunMode != PROD
c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent)
}
// ServeXML sends xml response.
func (c *Controller) ServeXML() {
hasIndent := true
if BConfig.RunMode == PROD {
hasIndent = false
}
hasIndent := BConfig.RunMode != PROD
c.Ctx.Output.XML(c.Data["xml"], hasIndent)
}
// ServeFormatted serve Xml OR Json, depending on the value of the Accept header
func (c *Controller) ServeFormatted() {
accept := c.Ctx.Input.Header("Accept")
switch accept {
case applicationJSON:
c.ServeJSON()
case applicationXML, textXML:
c.ServeXML()
default:
c.ServeJSON()
}
// ServeYAML sends yaml response.
func (c *Controller) ServeYAML() {
c.Ctx.Output.YAML(c.Data["yaml"])
}
// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header
func (c *Controller) ServeFormatted(encoding ...bool) {
hasIndent := BConfig.RunMode != PROD
hasEncoding := len(encoding) > 0 && encoding[0]
c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding)
}
// Input returns the input data map from POST or PUT request body and query string.

View File

@ -28,7 +28,7 @@ import (
)
const (
errorTypeHandler = iota
errorTypeHandler = iota
errorTypeController
)
@ -93,11 +93,6 @@ func showErr(err interface{}, ctx *context.Context, stack string) {
"BeegoVersion": VERSION,
"GoVersion": runtime.Version(),
}
if ctx.Output.Status != 0 {
ctx.ResponseWriter.WriteHeader(ctx.Output.Status)
} else {
ctx.ResponseWriter.WriteHeader(500)
}
t.Execute(ctx.ResponseWriter, data)
}
@ -366,7 +361,7 @@ func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := map[string]interface{}{
data := M{
"Title": http.StatusText(errCode),
"BeegoVersion": VERSION,
"Content": template.HTML(errContent),
@ -439,6 +434,9 @@ func exception(errCode string, ctx *context.Context) {
}
func executeError(err *errorInfo, ctx *context.Context, code int) {
//make sure to log the error in the access log
LogAccess(ctx, nil, code)
if err.errorType == errorTypeHandler {
ctx.ResponseWriter.WriteHeader(code)
err.handler(ctx.ResponseWriter, ctx.Request)

74
src/vendor/github.com/astaxie/beego/fs.go generated vendored Normal file
View File

@ -0,0 +1,74 @@
package beego
import (
"net/http"
"os"
"path/filepath"
)
type FileSystem struct {
}
func (d FileSystem) Open(name string) (http.File, error) {
return os.Open(name)
}
// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or
// directory in the tree, including root. All errors that arise visiting files
// and directories are filtered by walkFn.
func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error {
f, err := fs.Open(root)
if err != nil {
return err
}
info, err := f.Stat()
if err != nil {
err = walkFn(root, nil, err)
} else {
err = walk(fs, root, info, walkFn)
}
if err == filepath.SkipDir {
return nil
}
return err
}
// walk recursively descends path, calling walkFn.
func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error {
var err error
if !info.IsDir() {
return walkFn(path, info, nil)
}
dir, err := fs.Open(path)
if err != nil {
if err1 := walkFn(path, info, err); err1 != nil {
return err1
}
return err
}
defer dir.Close()
dirs, err := dir.Readdir(-1)
err1 := walkFn(path, info, err)
// If err != nil, walk can't walk into this directory.
// err1 != nil means walkFn want walk to skip this directory or stop walking.
// Therefore, if one of err and err1 isn't nil, walk will return.
if err != nil || err1 != nil {
// The caller's behavior is controlled by the return value, which is decided
// by walkFn. walkFn may ignore err and return nil.
// If walkFn returns SkipDir, it will be handled by the caller.
// So walk should return whatever walkFn returns.
return err1
}
for _, fileInfo := range dirs {
filename := filepath.Join(path, fileInfo.Name())
if err = walk(fs, filename, fileInfo, walkFn); err != nil {
if !fileInfo.IsDir() || err != filepath.SkipDir {
return err
}
}
}
return nil
}

39
src/vendor/github.com/astaxie/beego/go.mod generated vendored Normal file
View File

@ -0,0 +1,39 @@
module github.com/astaxie/beego
require (
github.com/Knetic/govaluate v3.0.0+incompatible // indirect
github.com/OwnLocal/goes v1.0.0
github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd
github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737
github.com/casbin/casbin v1.7.0
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58
github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb
github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c // indirect
github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a // indirect
github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 // indirect
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 // indirect
github.com/elazarl/go-bindata-assetfs v1.0.0
github.com/go-redis/redis v6.14.2+incompatible
github.com/go-sql-driver/mysql v1.4.1
github.com/gogo/protobuf v1.1.1
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/gomodule/redigo v2.0.0+incompatible
github.com/lib/pq v1.0.0
github.com/mattn/go-sqlite3 v1.10.0
github.com/pelletier/go-toml v1.2.0 // indirect
github.com/pkg/errors v0.8.0 // indirect
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 // indirect
github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373
github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d // indirect
github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec
github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c // indirect
github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect
golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a // indirect
gopkg.in/yaml.v2 v2.2.1
)
replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85
replace gopkg.in/yaml.v2 v2.2.1 => github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d

68
src/vendor/github.com/astaxie/beego/go.sum generated vendored Normal file
View File

@ -0,0 +1,68 @@
github.com/Knetic/govaluate v3.0.0+incompatible h1:7o6+MAPhYTCF0+fdvoz1xDedhRb4f6s9Tn1Tt7/WTEg=
github.com/Knetic/govaluate v3.0.0+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0=
github.com/OwnLocal/goes v1.0.0/go.mod h1:8rIFjBGTue3lCU0wplczcUgt9Gxgrkkrw7etMIcn8TM=
github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd h1:jZtX5jh5IOMu0fpOTC3ayh6QGSPJ/KWOv1lgPvbRw1M=
github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd/go.mod h1:1b+Y/CofkYwXMUU0OhQqGvsY2Bvgr4j6jfT699wyZKQ=
github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 h1:nYXb+3jF6Oq/j8R/y90XrKpreCxIalBWfeyeKymgOPk=
github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542/go.mod h1:kSeGC/p1AbBiEp5kat81+DSQrZenVBZXklMLaELspWU=
github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff h1:/kO0p2RTGLB8R5gub7ps0GmYpB2O8LXEoPq8tzFDCUI=
github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff/go.mod h1:PhH1ZhyCzHKt4uAasyx+ljRCgoezetRNf59CUtwUkqY=
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 h1:rRISKWyXfVxvoa702s91Zl5oREZTrR3yv+tXrrX7G/g=
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60=
github.com/casbin/casbin v1.7.0 h1:PuzlE8w0JBg/DhIqnkF1Dewf3z+qmUZMVN07PonvVUQ=
github.com/casbin/casbin v1.7.0/go.mod h1:c67qKN6Oum3UF5Q1+BByfFxkwKvhwW57ITjqwtzR1KE=
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg=
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80=
github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb h1:w3RapLhkA5+km9Z8vUkC6VCaskduJXvXwJg5neKnfDU=
github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb/go.mod h1:TWI8EKQMs5u5jLKW/tsb9VwauIrMIxQG1r5fMsswK5U=
github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c h1:K4FIibkr4//ziZKOKmt4RL0YImuTjLLBtwElf+F2lSQ=
github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c/go.mod h1:srVSlQLB8iXBVXHgnqemxUXqN6FCvClgCMPCsjBDR7c=
github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a h1:Y5XsLCEhtEI8qbD9RP3Qlv5FXdTDHxZM9UPUnMRgBp8=
github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a/go.mod h1:BQwMFlJzDjFDG3DJUdU0KORxn88UlsOULuxLExMh3Hs=
github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 h1:Lgdd/Qp96Qj8jqLpq2cI1I1X7BJnu06efS+XkhRoLUQ=
github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76/go.mod h1:vYwsqCOLxGiisLwp9rITslkFNpZD5rz43tf41QFkTWY=
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 h1:aaQcKT9WumO6JEJcRyTqFVq4XUZiUcKR2/GI31TOcz8=
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M=
github.com/elazarl/go-bindata-assetfs v1.0.0 h1:G/bYguwHIzWq9ZoyUQqrjTmJbbYn3j3CKKpKinvZLFk=
github.com/elazarl/go-bindata-assetfs v1.0.0/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4=
github.com/go-redis/redis v6.14.2+incompatible h1:UE9pLhzmWf+xHNmZsoccjXosPicuiNaInPgym8nzfg0=
github.com/go-redis/redis v6.14.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d h1:xy93KVe+KrIIwWDEAfQBdIfsiHJkepbYsDr+VY3g9/o=
github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85 h1:B7ZbAFz7NOmvpUE5RGtu3u0WIizy5GdvbNpEf4RPnWs=
github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85/go.mod h1:uZvAcrsnNaCxlh1HorK5dUQHGmEKPh2H/Rl1kehswPo=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0=
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o=
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 h1:xT+JlYxNGqyT+XcU8iUrN18JYed2TvG9yN5ULG2jATM=
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726/go.mod h1:3yhqj7WBBfRhbBlzyOC3gUxftwsU0u8gqevxwIHQpMw=
github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373 h1:p6IxqQMjab30l4lb9mmkIkkcE1yv6o0SKbPhW5pxqHI=
github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373/go.mod h1:mF1DpOSOUiJRMR+FDqaqu3EBqrybQtrDDszLUZ6oxPg=
github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d h1:NVwnfyR3rENtlz62bcrkXME3INVUa4lcdGt+opvxExs=
github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d/go.mod h1:AMEsy7v5z92TR1JKMkLLoaOQk++LVnOKL3ScbJ8GNGA=
github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec h1:q6XVwXmKvCRHRqesF3cSv6lNqqHi0QWOvgDlSohg8UA=
github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec/go.mod h1:QBvMkMya+gXctz3kmljlUCu/yB3GZ6oee+dUozsezQE=
github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c h1:3eGShk3EQf5gJCYW+WzA0TEJQd37HLOmlYF7N0YJwv0=
github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c/go.mod h1:Z4AUp2Km+PwemOoO/VB5AOx9XSsIItzFjoJlOSiYmn0=
github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b h1:0Ve0/CCjiAiyKddUMUn3RwIGlq2iTW4GuVzyoKBYO/8=
github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc=
golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 h1:et7+NAX3lLIk5qUCTA9QelBjGE/NkhzYw/mhnr0s7nI=
golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a h1:gOpx8G595UYyvj8UK4+OFyY4rx037g3fmfhe5SasG3U=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@ -1,39 +0,0 @@
package grace
import (
"errors"
"net"
"sync"
)
type graceConn struct {
net.Conn
server *Server
m sync.Mutex
closed bool
}
func (c *graceConn) Close() (err error) {
defer func() {
if r := recover(); r != nil {
switch x := r.(type) {
case string:
err = errors.New(x)
case error:
err = x
default:
err = errors.New("Unknown panic")
}
}
}()
c.m.Lock()
if c.closed {
c.m.Unlock()
return
}
c.server.wg.Done()
c.closed = true
c.m.Unlock()
return c.Conn.Close()
}

View File

@ -78,7 +78,7 @@ var (
DefaultReadTimeOut time.Duration
// DefaultWriteTimeOut is the HTTP Write timeout
DefaultWriteTimeOut time.Duration
// DefaultMaxHeaderBytes is the Max HTTP Herder size, default is 0, no limit
// DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit
DefaultMaxHeaderBytes int
// DefaultTimeout is the shutdown server's timeout. default is 60s
DefaultTimeout = 60 * time.Second
@ -122,7 +122,6 @@ func NewServer(addr string, handler http.Handler) (srv *Server) {
}
srv = &Server{
wg: sync.WaitGroup{},
sigChan: make(chan os.Signal),
isChild: isChild,
SignalHooks: map[int]map[os.Signal][]func(){
@ -137,20 +136,21 @@ func NewServer(addr string, handler http.Handler) (srv *Server) {
syscall.SIGTERM: {},
},
},
state: StateInit,
Network: "tcp",
state: StateInit,
Network: "tcp",
terminalChan: make(chan error), //no cache channel
}
srv.Server = &http.Server{
Addr: addr,
ReadTimeout: DefaultReadTimeOut,
WriteTimeout: DefaultWriteTimeOut,
MaxHeaderBytes: DefaultMaxHeaderBytes,
Handler: handler,
}
srv.Server = &http.Server{}
srv.Server.Addr = addr
srv.Server.ReadTimeout = DefaultReadTimeOut
srv.Server.WriteTimeout = DefaultWriteTimeOut
srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes
srv.Server.Handler = handler
runningServersOrder = append(runningServersOrder, addr)
runningServers[addr] = srv
return
return srv
}
// ListenAndServe refer http.ListenAndServe

View File

@ -1,62 +0,0 @@
package grace
import (
"net"
"os"
"syscall"
"time"
)
type graceListener struct {
net.Listener
stop chan error
stopped bool
server *Server
}
func newGraceListener(l net.Listener, srv *Server) (el *graceListener) {
el = &graceListener{
Listener: l,
stop: make(chan error),
server: srv,
}
go func() {
<-el.stop
el.stopped = true
el.stop <- el.Listener.Close()
}()
return
}
func (gl *graceListener) Accept() (c net.Conn, err error) {
tc, err := gl.Listener.(*net.TCPListener).AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
c = &graceConn{
Conn: tc,
server: gl.server,
}
gl.server.wg.Add(1)
return
}
func (gl *graceListener) Close() error {
if gl.stopped {
return syscall.EINVAL
}
gl.stop <- nil
return <-gl.stop
}
func (gl *graceListener) File() *os.File {
// returns a dup(2) - FD_CLOEXEC flag *not* set
tl := gl.Listener.(*net.TCPListener)
fl, _ := tl.File()
return fl
}

View File

@ -1,8 +1,11 @@
package grace
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
@ -10,7 +13,6 @@ import (
"os/exec"
"os/signal"
"strings"
"sync"
"syscall"
"time"
)
@ -18,14 +20,13 @@ import (
// Server embedded http.Server
type Server struct {
*http.Server
GraceListener net.Listener
SignalHooks map[int]map[os.Signal][]func()
tlsInnerListener *graceListener
wg sync.WaitGroup
sigChan chan os.Signal
isChild bool
state uint8
Network string
ln net.Listener
SignalHooks map[int]map[os.Signal][]func()
sigChan chan os.Signal
isChild bool
state uint8
Network string
terminalChan chan error
}
// Serve accepts incoming connections on the Listener l,
@ -33,11 +34,19 @@ type Server struct {
// The service goroutines read requests and then call srv.Handler to reply to them.
func (srv *Server) Serve() (err error) {
srv.state = StateRunning
err = srv.Server.Serve(srv.GraceListener)
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
srv.wg.Wait()
srv.state = StateTerminate
return
defer func() { srv.state = StateTerminate }()
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
// immediately return ErrServerClosed. Make sure the program doesn't exit
// and waits instead for Shutdown to return.
if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed {
log.Println(syscall.Getpid(), "Server.Serve() error:", err)
return err
}
log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
// wait for Shutdown to return
return <-srv.terminalChan
}
// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
@ -51,21 +60,19 @@ func (srv *Server) ListenAndServe() (err error) {
go srv.handleSignals()
l, err := srv.getListener(addr)
srv.ln, err = srv.getListener(addr)
if err != nil {
log.Println(err)
return err
}
srv.GraceListener = newGraceListener(l, srv)
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Kill()
err = process.Signal(syscall.SIGTERM)
if err != nil {
return err
}
@ -105,14 +112,67 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
go srv.handleSignals()
l, err := srv.getListener(addr)
ln, err := srv.getListener(addr)
if err != nil {
log.Println(err)
return err
}
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
srv.tlsInnerListener = newGraceListener(l, srv)
srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Signal(syscall.SIGTERM)
if err != nil {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
}
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming mutual TLS connections.
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
addr := srv.Addr
if addr == "" {
addr = ":https"
}
if srv.TLSConfig == nil {
srv.TLSConfig = &tls.Config{}
}
if srv.TLSConfig.NextProtos == nil {
srv.TLSConfig.NextProtos = []string{"http/1.1"}
}
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return
}
srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
pool := x509.NewCertPool()
data, err := ioutil.ReadFile(trustFile)
if err != nil {
log.Println(err)
return err
}
pool.AppendCertsFromPEM(data)
srv.TLSConfig.ClientCAs = pool
log.Println("Mutual HTTPS")
go srv.handleSignals()
ln, err := srv.getListener(addr)
if err != nil {
log.Println(err)
return err
}
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
@ -125,6 +185,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
}
@ -155,6 +216,20 @@ func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
return
}
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
// handleSignals listens for os Signals and calls any hooked in function that the
// user had registered with the signal.
func (srv *Server) handleSignals() {
@ -207,37 +282,14 @@ func (srv *Server) shutdown() {
}
srv.state = StateShuttingDown
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
ctx := context.Background()
if DefaultTimeout >= 0 {
go srv.serverTimeout(DefaultTimeout)
}
err := srv.GraceListener.Close()
if err != nil {
log.Println(syscall.Getpid(), "Listener.Close() error:", err)
} else {
log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
}
}
// serverTimeout forces the server to shutdown in a given timeout - whether it
// finished outstanding requests or not. if Read/WriteTimeout are not set or the
// max header size is very big a connection could hang
func (srv *Server) serverTimeout(d time.Duration) {
defer func() {
if r := recover(); r != nil {
log.Println("WaitGroup at 0", r)
}
}()
if srv.state != StateShuttingDown {
return
}
time.Sleep(d)
log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
for {
if srv.state == StateTerminate {
break
}
srv.wg.Done()
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
}
srv.terminalChan <- srv.Server.Shutdown(ctx)
}
func (srv *Server) fork() (err error) {
@ -251,12 +303,8 @@ func (srv *Server) fork() (err error) {
var files = make([]*os.File, len(runningServers))
var orderArgs = make([]string, len(runningServers))
for _, srvPtr := range runningServers {
switch srvPtr.GraceListener.(type) {
case *graceListener:
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
default:
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
}
f, _ := srvPtr.ln.(*net.TCPListener).File()
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f
orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
}

View File

@ -11,7 +11,7 @@ import (
"github.com/astaxie/beego/session"
)
//
// register MIME type with content type
func registerMime() error {
for k, v := range mimemaps {
mime.AddExtensionType(k, v)

View File

@ -21,6 +21,7 @@ import (
)
// Log levels to control the logging output.
// Deprecated: use github.com/astaxie/beego/logs instead.
const (
LevelEmergency = iota
LevelAlert
@ -33,75 +34,90 @@ const (
)
// BeeLogger references the used application logger.
// Deprecated: use github.com/astaxie/beego/logs instead.
var BeeLogger = logs.GetBeeLogger()
// SetLevel sets the global log level used by the simple logger.
// Deprecated: use github.com/astaxie/beego/logs instead.
func SetLevel(l int) {
logs.SetLevel(l)
}
// SetLogFuncCall set the CallDepth, default is 3
// Deprecated: use github.com/astaxie/beego/logs instead.
func SetLogFuncCall(b bool) {
logs.SetLogFuncCall(b)
}
// SetLogger sets a new logger.
// Deprecated: use github.com/astaxie/beego/logs instead.
func SetLogger(adaptername string, config string) error {
return logs.SetLogger(adaptername, config)
}
// Emergency logs a message at emergency level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Emergency(v ...interface{}) {
logs.Emergency(generateFmtStr(len(v)), v...)
}
// Alert logs a message at alert level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Alert(v ...interface{}) {
logs.Alert(generateFmtStr(len(v)), v...)
}
// Critical logs a message at critical level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Critical(v ...interface{}) {
logs.Critical(generateFmtStr(len(v)), v...)
}
// Error logs a message at error level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Error(v ...interface{}) {
logs.Error(generateFmtStr(len(v)), v...)
}
// Warning logs a message at warning level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Warning(v ...interface{}) {
logs.Warning(generateFmtStr(len(v)), v...)
}
// Warn compatibility alias for Warning()
// Deprecated: use github.com/astaxie/beego/logs instead.
func Warn(v ...interface{}) {
logs.Warn(generateFmtStr(len(v)), v...)
}
// Notice logs a message at notice level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Notice(v ...interface{}) {
logs.Notice(generateFmtStr(len(v)), v...)
}
// Informational logs a message at info level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Informational(v ...interface{}) {
logs.Informational(generateFmtStr(len(v)), v...)
}
// Info compatibility alias for Warning()
// Deprecated: use github.com/astaxie/beego/logs instead.
func Info(v ...interface{}) {
logs.Info(generateFmtStr(len(v)), v...)
}
// Debug logs a message at debug level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Debug(v ...interface{}) {
logs.Debug(generateFmtStr(len(v)), v...)
}
// Trace logs a message at trace level.
// compatibility alias for Warning()
// Deprecated: use github.com/astaxie/beego/logs instead.
func Trace(v ...interface{}) {
logs.Trace(generateFmtStr(len(v)), v...)
}

View File

@ -16,48 +16,57 @@ As of now this logs support console, file,smtp and conn.
First you must import it
import (
"github.com/astaxie/beego/logs"
)
```golang
import (
"github.com/astaxie/beego/logs"
)
```
Then init a Log (example with console adapter)
log := NewLogger(10000)
log.SetLogger("console", "")
```golang
log := logs.NewLogger(10000)
log.SetLogger("console", "")
```
> the first params stand for how many channel
Use it like this:
log.Trace("trace")
log.Info("info")
log.Warn("warning")
log.Debug("debug")
log.Critical("critical")
Use it like this:
```golang
log.Trace("trace")
log.Info("info")
log.Warn("warning")
log.Debug("debug")
log.Critical("critical")
```
## File adapter
Configure file adapter like this:
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`)
```golang
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`)
```
## Conn adapter
Configure like this:
log := NewLogger(1000)
log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`)
log.Info("info")
```golang
log := NewLogger(1000)
log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`)
log.Info("info")
```
## Smtp adapter
Configure like this:
log := NewLogger(10000)
log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`)
log.Critical("sendmail critical")
time.Sleep(time.Second * 30)
```golang
log := NewLogger(10000)
log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`)
log.Critical("sendmail critical")
time.Sleep(time.Second * 30)
```

83
src/vendor/github.com/astaxie/beego/logs/accesslog.go generated vendored Normal file
View File

@ -0,0 +1,83 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"bytes"
"strings"
"encoding/json"
"fmt"
"time"
)
const (
apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s"
apacheFormat = "APACHE_FORMAT"
jsonFormat = "JSON_FORMAT"
)
// AccessLogRecord struct for holding access log data.
type AccessLogRecord struct {
RemoteAddr string `json:"remote_addr"`
RequestTime time.Time `json:"request_time"`
RequestMethod string `json:"request_method"`
Request string `json:"request"`
ServerProtocol string `json:"server_protocol"`
Host string `json:"host"`
Status int `json:"status"`
BodyBytesSent int64 `json:"body_bytes_sent"`
ElapsedTime time.Duration `json:"elapsed_time"`
HTTPReferrer string `json:"http_referrer"`
HTTPUserAgent string `json:"http_user_agent"`
RemoteUser string `json:"remote_user"`
}
func (r *AccessLogRecord) json() ([]byte, error) {
buffer := &bytes.Buffer{}
encoder := json.NewEncoder(buffer)
disableEscapeHTML(encoder)
err := encoder.Encode(r)
return buffer.Bytes(), err
}
func disableEscapeHTML(i interface{}) {
if e, ok := i.(interface {
SetEscapeHTML(bool)
}); ok {
e.SetEscapeHTML(false)
}
}
// AccessLog - Format and print access log.
func AccessLog(r *AccessLogRecord, format string) {
var msg string
switch format {
case apacheFormat:
timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05")
msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent,
r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent)
case jsonFormat:
fallthrough
default:
jsonData, err := r.json()
if err != nil {
msg = fmt.Sprintf(`{"Error": "%s"}`, err)
} else {
msg = string(jsonData)
}
}
beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg))
}

View File

@ -1,28 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !windows
package logs
import "io"
type ansiColorWriter struct {
w io.Writer
mode outputMode
}
func (cw *ansiColorWriter) Write(p []byte) (int, error) {
return cw.w.Write(p)
}

View File

@ -63,7 +63,7 @@ func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error {
defer c.innerWriter.Close()
}
c.lg.println(when, msg)
c.lg.writeln(when, msg)
return nil
}

View File

@ -17,8 +17,10 @@ package logs
import (
"encoding/json"
"os"
"runtime"
"strings"
"time"
"github.com/shiena/ansicolor"
)
// brush is a color join function
@ -54,9 +56,9 @@ type consoleWriter struct {
// NewConsole create ConsoleWriter returning as LoggerInterface.
func NewConsole() Logger {
cw := &consoleWriter{
lg: newLogWriter(os.Stdout),
lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)),
Level: LevelDebug,
Colorful: runtime.GOOS != "windows",
Colorful: true,
}
return cw
}
@ -67,11 +69,7 @@ func (c *consoleWriter) Init(jsonConfig string) error {
if len(jsonConfig) == 0 {
return nil
}
err := json.Unmarshal([]byte(jsonConfig), c)
if runtime.GOOS == "windows" {
c.Colorful = false
}
return err
return json.Unmarshal([]byte(jsonConfig), c)
}
// WriteMsg write message in console.
@ -80,9 +78,9 @@ func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error {
return nil
}
if c.Colorful {
msg = colors[level](msg)
msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1)
}
c.lg.println(when, msg)
c.lg.writeln(when, msg)
return nil
}

View File

@ -21,6 +21,7 @@ import (
"fmt"
"io"
"os"
"path"
"path/filepath"
"strconv"
"strings"
@ -40,6 +41,9 @@ type fileLogWriter struct {
MaxLines int `json:"maxlines"`
maxLinesCurLines int
MaxFiles int `json:"maxfiles"`
MaxFilesCurFiles int
// Rotate at size
MaxSize int `json:"maxsize"`
maxSizeCurSize int
@ -50,6 +54,12 @@ type fileLogWriter struct {
dailyOpenDate int
dailyOpenTime time.Time
// Rotate hourly
Hourly bool `json:"hourly"`
MaxHours int64 `json:"maxhours"`
hourlyOpenDate int
hourlyOpenTime time.Time
Rotate bool `json:"rotate"`
Level int `json:"level"`
@ -66,25 +76,30 @@ func newFileWriter() Logger {
w := &fileLogWriter{
Daily: true,
MaxDays: 7,
Hourly: false,
MaxHours: 168,
Rotate: true,
RotatePerm: "0440",
Level: LevelTrace,
Perm: "0660",
MaxLines: 10000000,
MaxFiles: 999,
MaxSize: 1 << 28,
}
return w
}
// Init file logger with json config.
// jsonConfig like:
// {
// "filename":"logs/beego.log",
// "maxLines":10000,
// "maxsize":1024,
// "daily":true,
// "maxDays":15,
// "rotate":true,
// "perm":"0600"
// }
// {
// "filename":"logs/beego.log",
// "maxLines":10000,
// "maxsize":1024,
// "daily":true,
// "maxDays":15,
// "rotate":true,
// "perm":"0600"
// }
func (w *fileLogWriter) Init(jsonConfig string) error {
err := json.Unmarshal([]byte(jsonConfig), w)
if err != nil {
@ -115,10 +130,16 @@ func (w *fileLogWriter) startLogger() error {
return w.initFd()
}
func (w *fileLogWriter) needRotate(size int, day int) bool {
func (w *fileLogWriter) needRotateDaily(size int, day int) bool {
return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
(w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
(w.Daily && day != w.dailyOpenDate)
}
func (w *fileLogWriter) needRotateHourly(size int, hour int) bool {
return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
(w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
(w.Hourly && hour != w.hourlyOpenDate)
}
@ -127,14 +148,23 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
if level > w.Level {
return nil
}
h, d := formatTimeHeader(when)
msg = string(h) + msg + "\n"
hd, d, h := formatTimeHeader(when)
msg = string(hd) + msg + "\n"
if w.Rotate {
w.RLock()
if w.needRotate(len(msg), d) {
if w.needRotateHourly(len(msg), h) {
w.RUnlock()
w.Lock()
if w.needRotate(len(msg), d) {
if w.needRotateHourly(len(msg), h) {
if err := w.doRotate(when); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
} else if w.needRotateDaily(len(msg), d) {
w.RUnlock()
w.Lock()
if w.needRotateDaily(len(msg), d) {
if err := w.doRotate(when); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
@ -161,6 +191,10 @@ func (w *fileLogWriter) createLogFile() (*os.File, error) {
if err != nil {
return nil, err
}
filepath := path.Dir(w.Filename)
os.MkdirAll(filepath, os.FileMode(perm))
fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm))
if err == nil {
// Make sure file perm is user set perm cause of `os.OpenFile` will obey umask
@ -178,11 +212,15 @@ func (w *fileLogWriter) initFd() error {
w.maxSizeCurSize = int(fInfo.Size())
w.dailyOpenTime = time.Now()
w.dailyOpenDate = w.dailyOpenTime.Day()
w.hourlyOpenTime = time.Now()
w.hourlyOpenDate = w.hourlyOpenTime.Hour()
w.maxLinesCurLines = 0
if w.Daily {
if w.Hourly {
go w.hourlyRotate(w.hourlyOpenTime)
} else if w.Daily {
go w.dailyRotate(w.dailyOpenTime)
}
if fInfo.Size() > 0 {
if fInfo.Size() > 0 && w.MaxLines > 0 {
count, err := w.lines()
if err != nil {
return err
@ -198,7 +236,22 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) {
tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100))
<-tm.C
w.Lock()
if w.needRotate(0, time.Now().Day()) {
if w.needRotateDaily(0, time.Now().Day()) {
if err := w.doRotate(time.Now()); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
}
func (w *fileLogWriter) hourlyRotate(openTime time.Time) {
y, m, d := openTime.Add(1 * time.Hour).Date()
h, _, _ := openTime.Add(1 * time.Hour).Clock()
nextHour := time.Date(y, m, d, h, 0, 0, 0, openTime.Location())
tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100))
<-tm.C
w.Lock()
if w.needRotateHourly(0, time.Now().Hour()) {
if err := w.doRotate(time.Now()); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
@ -238,8 +291,10 @@ func (w *fileLogWriter) lines() (int, error) {
func (w *fileLogWriter) doRotate(logTime time.Time) error {
// file exists
// Find the next available number
num := 1
num := w.MaxFilesCurFiles + 1
fName := ""
format := ""
var openTime time.Time
rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64)
if err != nil {
return err
@ -251,19 +306,26 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
goto RESTART_LOGGER
}
if w.Hourly {
format = "2006010215"
openTime = w.hourlyOpenTime
} else if w.Daily {
format = "2006-01-02"
openTime = w.dailyOpenTime
}
// only when one of them be setted, then the file would be splited
if w.MaxLines > 0 || w.MaxSize > 0 {
for ; err == nil && num <= 999; num++ {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format("2006-01-02"), num, w.suffix)
for ; err == nil && num <= w.MaxFiles; num++ {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format(format), num, w.suffix)
_, err = os.Lstat(fName)
}
} else {
fName = fmt.Sprintf("%s.%s%s", w.fileNameOnly, w.dailyOpenTime.Format("2006-01-02"), w.suffix)
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix)
_, err = os.Lstat(fName)
for ; err == nil && num <= 999; num++ {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", w.dailyOpenTime.Format("2006-01-02"), num, w.suffix)
_, err = os.Lstat(fName)
}
w.MaxFilesCurFiles = num
}
// return error if the last file checked still existed
if err == nil {
return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename)
@ -307,13 +369,21 @@ func (w *fileLogWriter) deleteOldLog() {
if info == nil {
return
}
if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
strings.HasSuffix(filepath.Base(path), w.suffix) {
os.Remove(path)
}
}
if w.Hourly {
if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
strings.HasSuffix(filepath.Base(path), w.suffix) {
os.Remove(path)
}
}
} else if w.Daily {
if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
strings.HasSuffix(filepath.Base(path), w.suffix) {
os.Remove(path)
}
}
}
return
})
}

View File

@ -92,7 +92,7 @@ type Logger interface {
}
var adapters = make(map[string]newLoggerFunc)
var levelPrefix = [LevelDebug + 1]string{"[M] ", "[A] ", "[C] ", "[E] ", "[W] ", "[N] ", "[I] ", "[D] "}
var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"}
// Register makes a log provide available by the provided name.
// If Register is called twice with the same name or if driver is nil,
@ -116,6 +116,7 @@ type BeeLogger struct {
enableFuncCallDepth bool
loggerFuncCallDepth int
asynchronous bool
prefix string
msgChanLen int64
msgChan chan *logMsg
signalChan chan string
@ -186,12 +187,12 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error {
}
}
log, ok := adapters[adapterName]
logAdapter, ok := adapters[adapterName]
if !ok {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
}
lg := log()
lg := logAdapter()
err := lg.Init(config)
if err != nil {
fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error())
@ -267,6 +268,9 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error
if len(v) > 0 {
msg = fmt.Sprintf(msg, v...)
}
msg = bl.prefix + " " + msg
when := time.Now()
if bl.enableFuncCallDepth {
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
@ -283,7 +287,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error
// set to emergency to ensure all log will be print out correctly
logLevel = LevelEmergency
} else {
msg = levelPrefix[logLevel] + msg
msg = levelPrefix[logLevel] + " " + msg
}
if bl.asynchronous {
@ -305,6 +309,11 @@ func (bl *BeeLogger) SetLevel(l int) {
bl.level = l
}
// GetLevel Get Current log message level.
func (bl *BeeLogger) GetLevel() int {
return bl.level
}
// SetLogFuncCallDepth set log funcCallDepth
func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
bl.loggerFuncCallDepth = d
@ -320,6 +329,11 @@ func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
bl.enableFuncCallDepth = b
}
// set prefix
func (bl *BeeLogger) SetPrefix(s string) {
bl.prefix = s
}
// start logger chan reading.
// when chan is not empty, write logs.
func (bl *BeeLogger) startLogger() {
@ -544,6 +558,11 @@ func SetLevel(l int) {
beeLogger.SetLevel(l)
}
// SetPrefix sets the prefix
func SetPrefix(s string) {
beeLogger.SetPrefix(s)
}
// EnableFuncCallDepth enable log funcCallDepth
func EnableFuncCallDepth(b bool) {
beeLogger.enableFuncCallDepth = b

View File

@ -15,9 +15,8 @@
package logs
import (
"fmt"
"io"
"os"
"runtime"
"sync"
"time"
)
@ -31,47 +30,13 @@ func newLogWriter(wr io.Writer) *logWriter {
return &logWriter{writer: wr}
}
func (lg *logWriter) println(when time.Time, msg string) {
func (lg *logWriter) writeln(when time.Time, msg string) {
lg.Lock()
h, _ := formatTimeHeader(when)
h, _, _ := formatTimeHeader(when)
lg.writer.Write(append(append(h, msg...), '\n'))
lg.Unlock()
}
type outputMode int
// DiscardNonColorEscSeq supports the divided color escape sequence.
// But non-color escape sequence is not output.
// Please use the OutputNonColorEscSeq If you want to output a non-color
// escape sequences such as ncurses. However, it does not support the divided
// color escape sequence.
const (
_ outputMode = iota
DiscardNonColorEscSeq
OutputNonColorEscSeq
)
// NewAnsiColorWriter creates and initializes a new ansiColorWriter
// using io.Writer w as its initial contents.
// In the console of Windows, which change the foreground and background
// colors of the text by the escape sequence.
// In the console of other systems, which writes to w all text.
func NewAnsiColorWriter(w io.Writer) io.Writer {
return NewModeAnsiColorWriter(w, DiscardNonColorEscSeq)
}
// NewModeAnsiColorWriter create and initializes a new ansiColorWriter
// by specifying the outputMode.
func NewModeAnsiColorWriter(w io.Writer, mode outputMode) io.Writer {
if _, ok := w.(*ansiColorWriter); !ok {
return &ansiColorWriter{
w: w,
mode: mode,
}
}
return w
}
const (
y1 = `0123456789`
y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
@ -87,13 +52,15 @@ const (
mi2 = `012345678901234567890123456789012345678901234567890123456789`
s1 = `000000000011111111112222222222333333333344444444445555555555`
s2 = `012345678901234567890123456789012345678901234567890123456789`
ns1 = `0123456789`
)
func formatTimeHeader(when time.Time) ([]byte, int) {
func formatTimeHeader(when time.Time) ([]byte, int, int) {
y, mo, d := when.Date()
h, mi, s := when.Clock()
//len("2006/01/02 15:04:05 ")==20
var buf [20]byte
ns := when.Nanosecond() / 1000000
//len("2006/01/02 15:04:05.123 ")==24
var buf [24]byte
buf[0] = y1[y/1000%10]
buf[1] = y2[y/100]
@ -114,9 +81,14 @@ func formatTimeHeader(when time.Time) ([]byte, int) {
buf[16] = ':'
buf[17] = s1[s]
buf[18] = s2[s]
buf[19] = ' '
buf[19] = '.'
buf[20] = ns1[ns/100]
buf[21] = ns1[ns%100/10]
buf[22] = ns1[ns%10]
return buf[0:], d
buf[23] = ' '
return buf[0:], d, h
}
var (
@ -139,63 +111,65 @@ var (
reset = string([]byte{27, 91, 48, 109})
)
var once sync.Once
var colorMap map[string]string
func initColor() {
if runtime.GOOS == "windows" {
green = w32Green
white = w32White
yellow = w32Yellow
red = w32Red
blue = w32Blue
magenta = w32Magenta
cyan = w32Cyan
}
colorMap = map[string]string{
//by color
"green": green,
"white": white,
"yellow": yellow,
"red": red,
//by method
"GET": blue,
"POST": cyan,
"PUT": yellow,
"DELETE": red,
"PATCH": green,
"HEAD": magenta,
"OPTIONS": white,
}
}
// ColorByStatus return color by http code
// 2xx return Green
// 3xx return White
// 4xx return Yellow
// 5xx return Red
func ColorByStatus(cond bool, code int) string {
func ColorByStatus(code int) string {
once.Do(initColor)
switch {
case code >= 200 && code < 300:
return map[bool]string{true: green, false: w32Green}[cond]
return colorMap["green"]
case code >= 300 && code < 400:
return map[bool]string{true: white, false: w32White}[cond]
return colorMap["white"]
case code >= 400 && code < 500:
return map[bool]string{true: yellow, false: w32Yellow}[cond]
return colorMap["yellow"]
default:
return map[bool]string{true: red, false: w32Red}[cond]
return colorMap["red"]
}
}
// ColorByMethod return color by http code
// GET return Blue
// POST return Cyan
// PUT return Yellow
// DELETE return Red
// PATCH return Green
// HEAD return Magenta
// OPTIONS return WHITE
func ColorByMethod(cond bool, method string) string {
switch method {
case "GET":
return map[bool]string{true: blue, false: w32Blue}[cond]
case "POST":
return map[bool]string{true: cyan, false: w32Cyan}[cond]
case "PUT":
return map[bool]string{true: yellow, false: w32Yellow}[cond]
case "DELETE":
return map[bool]string{true: red, false: w32Red}[cond]
case "PATCH":
return map[bool]string{true: green, false: w32Green}[cond]
case "HEAD":
return map[bool]string{true: magenta, false: w32Magenta}[cond]
case "OPTIONS":
return map[bool]string{true: white, false: w32White}[cond]
default:
return reset
func ColorByMethod(method string) string {
once.Do(initColor)
if c := colorMap[method]; c != "" {
return c
}
return reset
}
// Guard Mutex to guarantee atomic of W32Debug(string) function
var mu sync.Mutex
// W32Debug Helper method to output colored logs in Windows terminals
func W32Debug(msg string) {
mu.Lock()
defer mu.Unlock()
current := time.Now()
w := NewAnsiColorWriter(os.Stdout)
fmt.Fprintf(w, "[beego] %v %s\n", current.Format("2006/01/02 - 15:04:05"), msg)
// ResetColor return reset color
func ResetColor() string {
return reset
}

View File

@ -67,7 +67,10 @@ func (f *multiFileLogWriter) Init(config string) error {
jsonMap["level"] = i
bs, _ := json.Marshal(jsonMap)
writer = newFileWriter().(*fileLogWriter)
writer.Init(string(bs))
err := writer.Init(string(bs))
if err != nil {
return err
}
f.writers[i] = writer
}
}

View File

@ -207,11 +207,11 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
for _, ni := range ns {
for k, v := range ni.handlers.routers {
if t, ok := n.handlers.routers[k]; ok {
if _, ok := n.handlers.routers[k]; ok {
addPrefix(v, ni.prefix)
n.handlers.routers[k].AddTree(ni.prefix, v)
} else {
t = NewTree()
t := NewTree()
t.AddTree(ni.prefix, v)
addPrefix(t, ni.prefix)
n.handlers.routers[k] = t
@ -236,11 +236,11 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
func AddNamespace(nl ...*Namespace) {
for _, n := range nl {
for k, v := range n.handlers.routers {
if t, ok := BeeApp.Handlers.routers[k]; ok {
if _, ok := BeeApp.Handlers.routers[k]; ok {
addPrefix(v, n.prefix)
BeeApp.Handlers.routers[k].AddTree(n.prefix, v)
} else {
t = NewTree()
t := NewTree()
t.AddTree(n.prefix, v)
addPrefix(t, n.prefix)
BeeApp.Handlers.routers[k] = t

View File

@ -61,6 +61,9 @@ func init() {
// set default database
orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
// create table
orm.RunSyncdb("default", false, true)
}
func main() {

View File

@ -51,12 +51,14 @@ checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeCharField:
case TypeVarCharField:
if al.Driver == DRPostgres && fi.toText {
col = T["string-text"]
} else {
col = fmt.Sprintf(T["string"], fieldSize)
}
case TypeCharField:
col = fmt.Sprintf(T["string-char"], fieldSize)
case TypeTextField:
col = T["string-text"]
case TypeTimeField:
@ -96,13 +98,13 @@ checkColumn:
}
case TypeJSONField:
if al.Driver != DRPostgres {
fieldType = TypeCharField
fieldType = TypeVarCharField
goto checkColumn
}
col = T["json"]
case TypeJsonbField:
if al.Driver != DRPostgres {
fieldType = TypeCharField
fieldType = TypeVarCharField
goto checkColumn
}
col = T["jsonb"]
@ -195,6 +197,10 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
if strings.Contains(column, "%COL%") {
column = strings.Replace(column, "%COL%", fi.column, -1)
}
if fi.description != "" {
column += " " + fmt.Sprintf("COMMENT '%s'",fi.description)
}
columns = append(columns, column)
}

View File

@ -142,7 +142,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else {
value = field.Bool()
}
case TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField:
case TypeVarCharField, TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField:
if ns, ok := field.Interface().(sql.NullString); ok {
value = nil
if ns.Valid {
@ -536,6 +536,8 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
updates := make([]string, len(names))
var conflitValue interface{}
for i, v := range names {
// identifier in database may not be case-sensitive, so quote it
v = fmt.Sprintf("%s%s%s", Q, v, Q)
marks[i] = "?"
valueStr := argsMap[strings.ToLower(v)]
if v == args0 {
@ -619,6 +621,31 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
return 0, err
}
var findAutoNowAdd, findAutoNow bool
var index int
for i, col := range setNames {
if mi.fields.GetByColumn(col).autoNowAdd {
index = i
findAutoNowAdd = true
}
if mi.fields.GetByColumn(col).autoNow {
findAutoNow = true
}
}
if findAutoNowAdd {
setNames = append(setNames[0:index], setNames[index+1:]...)
setValues = append(setValues[0:index], setValues[index+1:]...)
}
if !findAutoNow {
for col, info := range mi.fields.columns {
if info.autoNow {
setNames = append(setNames, col)
setValues = append(setValues, time.Now())
}
}
}
setValues = append(setValues, pkValue)
Q := d.ins.TableQuote()
@ -760,7 +787,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, values...)
var err error
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, values...)
} else {
res, err = q.Exec(query, values...)
}
if err == nil {
return res.RowsAffected()
}
@ -849,11 +882,16 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
for i := range marks {
marks[i] = "?"
}
sql := fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, args...)
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, args...)
} else {
res, err = q.Exec(query, args...)
}
if err == nil {
num, err := res.RowsAffected()
if err != nil {
@ -926,7 +964,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
maps[fi.column] = true
}
} else {
panic(fmt.Errorf("wrong field/column name `%s`", col))
return 0, fmt.Errorf("wrong field/column name `%s`", col)
}
}
if hasRel {
@ -969,14 +1007,25 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
}
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
if qs.forupdate {
query += " FOR UPDATE"
}
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
r, err := q.Query(query, args...)
if err != nil {
return 0, err
var err error
if qs != nil && qs.forContext {
rs, err = q.QueryContext(qs.ctx, query, args...)
if err != nil {
return 0, err
}
} else {
rs, err = q.Query(query, args...)
if err != nil {
return 0, err
}
}
rs = r
refs := make([]interface{}, colsNum)
for i := range refs {
@ -1105,8 +1154,12 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
d.ins.ReplaceMarks(&query)
row := q.QueryRow(query, args...)
var row *sql.Row
if qs != nil && qs.forContext {
row = q.QueryRowContext(qs.ctx, query, args...)
} else {
row = q.QueryRow(query, args...)
}
err = row.Scan(&cnt)
return
}
@ -1240,7 +1293,7 @@ setValue:
}
value = b
}
case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if str == nil {
value = ToStr(val)
} else {
@ -1386,7 +1439,7 @@ setValue:
field.SetBool(value.(bool))
}
}
case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if isNative {
if ns, ok := field.Interface().(sql.NullString); ok {
if value == nil {

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"database/sql"
"fmt"
"reflect"
@ -103,6 +104,96 @@ func (ac *_dbCache) getDefault() (al *alias) {
return
}
type DB struct {
*sync.RWMutex
DB *sql.DB
stmts map[string]*sql.Stmt
}
func (d *DB) Begin() (*sql.Tx, error) {
return d.DB.Begin()
}
func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
return d.DB.BeginTx(ctx, opts)
}
func (d *DB) getStmt(query string) (*sql.Stmt, error) {
d.RLock()
if stmt, ok := d.stmts[query]; ok {
d.RUnlock()
return stmt, nil
}
d.RUnlock()
stmt, err := d.Prepare(query)
if err != nil {
return nil, err
}
d.Lock()
d.stmts[query] = stmt
d.Unlock()
return stmt, nil
}
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
return d.DB.Prepare(query)
}
func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return d.DB.PrepareContext(ctx, query)
}
func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.Exec(args...)
}
func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.ExecContext(ctx, args...)
}
func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.Query(args...)
}
func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.QueryContext(ctx, args...)
}
func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
stmt, err := d.getStmt(query)
if err != nil {
panic(err)
}
return stmt.QueryRow(args...)
}
func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := d.getStmt(query)
if err != nil {
panic(err)
}
return stmt.QueryRowContext(ctx, args)
}
type alias struct {
Name string
Driver DriverType
@ -110,7 +201,7 @@ type alias struct {
DataSource string
MaxIdleConns int
MaxOpenConns int
DB *sql.DB
DB *DB
DbBaser dbBaser
TZ *time.Location
Engine string
@ -119,7 +210,7 @@ type alias struct {
func detectTZ(al *alias) {
// orm timezone system match database
// default use Local
al.TZ = time.Local
al.TZ = DefaultTimeLoc
if al.DriverName == "sphinx" {
return
@ -136,7 +227,9 @@ func detectTZ(al *alias) {
}
t, err := time.Parse("-07:00:00", tz)
if err == nil {
al.TZ = t.Location()
if t.Location().String() != "" {
al.TZ = t.Location()
}
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
@ -174,7 +267,11 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
al := new(alias)
al.Name = aliasName
al.DriverName = driverName
al.DB = db
al.DB = &DB{
RWMutex: new(sync.RWMutex),
DB: db,
stmts: make(map[string]*sql.Stmt),
}
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
@ -270,7 +367,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns
al.DB.SetMaxIdleConns(maxIdleConns)
al.DB.DB.SetMaxIdleConns(maxIdleConns)
}
// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
@ -294,7 +391,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
}
al, ok := dataBaseCache.get(name)
if ok {
return al.DB, nil
return al.DB.DB, nil
}
return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
}

View File

@ -46,6 +46,7 @@ var mysqlTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-char": "char(%d)",
"string-text": "longtext",
"time.Time-date": "date",
"time.Time": "datetime",

View File

@ -34,6 +34,7 @@ var oracleTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "VARCHAR2(%d)",
"string-char": "CHAR(%d)",
"string-text": "VARCHAR2(%d)",
"time.Time-date": "DATE",
"time.Time": "TIMESTAMP",

View File

@ -43,6 +43,7 @@ var postgresTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-char": "char(%d)",
"string-text": "text",
"time.Time-date": "date",
"time.Time": "timestamp with time zone",

View File

@ -43,6 +43,7 @@ var sqliteTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-char": "character(%d)",
"string-text": "text",
"time.Time-date": "date",
"time.Time": "datetime",

View File

@ -372,7 +372,13 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe
operator = "exact"
}
operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
var operSQL string
var args []interface{}
if p.isRaw {
operSQL = p.sql
} else {
operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
}
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)

View File

@ -52,7 +52,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
return m
}
// get orderd model info
// get ordered model info
func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders))
for _, table := range mc.orders {

View File

@ -89,7 +89,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
modelCache.set(table, mi)
}
// boostrap models
// bootstrap models
func bootStrap() {
if modelCache.done {
return
@ -332,14 +332,14 @@ func RegisterModelWithSuffix(suffix string, models ...interface{}) {
}
}
// BootStrap bootrap models.
// BootStrap bootstrap models.
// make all model parsed and can not add more models
func BootStrap() {
modelCache.Lock()
defer modelCache.Unlock()
if modelCache.done {
return
}
modelCache.Lock()
defer modelCache.Unlock()
bootStrap()
modelCache.done = true
}

View File

@ -23,6 +23,7 @@ import (
// Define the Type enum
const (
TypeBooleanField = 1 << iota
TypeVarCharField
TypeCharField
TypeTextField
TypeTimeField
@ -49,9 +50,9 @@ const (
// Define some logic enum
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 5 << 6
IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 9 << 10
IsRelField = ^-RelReverseMany >> 17 << 18
IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7
IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11
IsRelField = ^-RelReverseMany >> 18 << 19
IsFieldType = ^-RelReverseMany<<1 + 1
)
@ -85,7 +86,7 @@ func (e *BooleanField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := StrTo(d).Bool()
if err != nil {
if err == nil {
e.Set(v)
}
return err
@ -126,7 +127,7 @@ func (e *CharField) String() string {
// FieldType return the enum type
func (e *CharField) FieldType() int {
return TypeCharField
return TypeVarCharField
}
// SetRaw set the interface to string
@ -190,7 +191,7 @@ func (e *TimeField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := timeParse(d, formatTime)
if err != nil {
if err == nil {
e.Set(v)
}
return err
@ -232,7 +233,7 @@ func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
// String convert datatime to string
// String convert datetime to string
func (e *DateField) String() string {
return e.Value().String()
}
@ -249,7 +250,7 @@ func (e *DateField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := timeParse(d, formatDate)
if err != nil {
if err == nil {
e.Set(v)
}
return err
@ -272,12 +273,12 @@ var _ Fielder = new(DateField)
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
// Value return the datatime value
// Value return the datetime value
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
// Set set the time.Time to datatime
// Set set the time.Time to datetime
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
@ -299,7 +300,7 @@ func (e *DateTimeField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := timeParse(d, formatDateTime)
if err != nil {
if err == nil {
e.Set(v)
}
return err
@ -309,12 +310,12 @@ func (e *DateTimeField) SetRaw(value interface{}) error {
return nil
}
// RawValue return the datatime value
// RawValue return the datetime value
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
// verify datatime implement fielder
// verify datetime implement fielder
var _ Fielder = new(DateTimeField)
// FloatField A floating-point number represented in go by a float32 value.
@ -349,9 +350,10 @@ func (e *FloatField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := StrTo(d).Float64()
if err != nil {
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
}
@ -396,9 +398,10 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := StrTo(d).Int16()
if err != nil {
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
}
@ -443,9 +446,10 @@ func (e *IntegerField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := StrTo(d).Int32()
if err != nil {
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
}
@ -490,9 +494,10 @@ func (e *BigIntegerField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := StrTo(d).Int64()
if err != nil {
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
}
@ -537,9 +542,10 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := StrTo(d).Uint16()
if err != nil {
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
}
@ -584,9 +590,10 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := StrTo(d).Uint32()
if err != nil {
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
}
@ -631,9 +638,10 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
e.Set(d)
case string:
v, err := StrTo(d).Uint64()
if err != nil {
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
}

View File

@ -136,6 +136,7 @@ type fieldInfo struct {
decimals int
isFielder bool // implement Fielder interface
onDelete string
description string
}
// new field info
@ -244,8 +245,10 @@ checkType:
if err != nil {
goto end
}
if fieldType == TypeCharField {
if fieldType == TypeVarCharField {
switch tags["type"] {
case "char":
fieldType = TypeCharField
case "text":
fieldType = TypeTextField
case "json":
@ -298,6 +301,7 @@ checkType:
fi.sf = sf
fi.fullName = mi.fullName + mName + "." + sf.Name
fi.description = tags["description"]
fi.null = attrs["null"]
fi.index = attrs["index"]
fi.auto = attrs["auto"]
@ -357,7 +361,7 @@ checkType:
switch fieldType {
case TypeBooleanField:
case TypeCharField, TypeJSONField, TypeJsonbField:
case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField:
if size != "" {
v, e := StrTo(size).Int32()
if e != nil {

View File

@ -75,7 +75,8 @@ func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int)
break
}
//record current field index
fi.fieldIndex = append(index, i)
fi.fieldIndex = append(fi.fieldIndex, index...)
fi.fieldIndex = append(fi.fieldIndex, i)
fi.mi = mi
fi.inModel = true
if !mi.fields.Add(fi) {

View File

@ -44,6 +44,7 @@ var supportTag = map[string]int{
"decimals": 2,
"on_delete": 2,
"type": 2,
"description": 2,
}
// get reflect.Type name with package path.
@ -65,7 +66,7 @@ func getTableName(val reflect.Value) string {
return snakeString(reflect.Indirect(val).Type().Name())
}
// get table engine, mysiam or innodb.
// get table engine, myisam or innodb.
func getTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine")
if fun.IsValid() {
@ -109,7 +110,7 @@ func getTableUnique(val reflect.Value) [][]string {
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
column := col
if col == "" {
column = snakeString(sf.Name)
column = nameStrategyMap[nameStrategy](sf.Name)
}
switch ft {
case RelForeignKey, RelOneToOne:
@ -149,7 +150,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
case reflect.TypeOf(new(bool)):
ft = TypeBooleanField
case reflect.TypeOf(new(string)):
ft = TypeCharField
ft = TypeVarCharField
case reflect.TypeOf(new(time.Time)):
ft = TypeDateTimeField
default:
@ -176,7 +177,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
case reflect.Bool:
ft = TypeBooleanField
case reflect.String:
ft = TypeCharField
ft = TypeVarCharField
default:
if elm.Interface() == nil {
panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val))
@ -189,7 +190,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
case sql.NullBool:
ft = TypeBooleanField
case sql.NullString:
ft = TypeCharField
ft = TypeVarCharField
case time.Time:
ft = TypeDateTimeField
}

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// +build go1.8
// Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage
//
@ -52,11 +54,13 @@
package orm
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"reflect"
"sync"
"time"
)
@ -69,7 +73,7 @@ const (
var (
Debug = false
DebugLog = NewLog(os.Stdout)
DefaultRowsLimit = 1000
DefaultRowsLimit = -1
DefaultRelsDepth = 2
DefaultTimeLoc = time.Local
ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
@ -422,7 +426,7 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
var name string
if table, ok := ptrStructOrTableName.(string); ok {
name = snakeString(table)
name = nameStrategyMap[defaultNameStrategy](table)
if mi, ok := modelCache.get(name); ok {
qs = newQuerySet(o, mi)
}
@ -458,11 +462,15 @@ func (o *orm) Using(name string) error {
// begin transaction
func (o *orm) Begin() error {
return o.BeginTx(context.Background(), nil)
}
func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error {
if o.isTx {
return ErrTxHasBegan
}
var tx *sql.Tx
tx, err := o.db.(txer).Begin()
tx, err := o.db.(txer).BeginTx(ctx, opts)
if err != nil {
return err
}
@ -515,6 +523,15 @@ func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
// return sql.DBStats for current database
func (o *orm) DBStats() *sql.DBStats {
if o.alias != nil && o.alias.DB != nil {
stats := o.alias.DB.DB.Stats()
return &stats
}
return nil
}
// NewOrm create new orm
func NewOrm() Ormer {
BootStrap() // execute only once
@ -541,6 +558,13 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
al.Name = aliasName
al.DriverName = driverName
al.DB = &DB{
RWMutex: new(sync.RWMutex),
DB: db,
stmts: make(map[string]*sql.Stmt),
}
detectTZ(al)
o := new(orm)
o.alias = al

View File

@ -31,6 +31,8 @@ type condValue struct {
isOr bool
isNot bool
isCond bool
isRaw bool
sql string
}
// Condition struct.
@ -45,6 +47,15 @@ func NewCondition() *Condition {
return c
}
// Raw add raw sql to condition
func (c Condition) Raw(expr string, sql string) *Condition {
if len(sql) == 0 {
panic(fmt.Errorf("<Condition.Raw> sql cannot empty"))
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true})
return &c
}
// And add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"database/sql"
"fmt"
"io"
@ -28,6 +29,9 @@ type Log struct {
*log.Logger
}
//costomer log func
var LogFunc func(query map[string]interface{})
// NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log {
d := new(Log)
@ -36,12 +40,15 @@ func NewLog(out io.Writer) *Log {
}
func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) {
var logMap = make(map[string]interface{})
sub := time.Now().Sub(t) / 1e5
elsp := float64(int(sub)) / 10.0
logMap["cost_time"] = elsp
flag := " OK"
if err != nil {
flag = "FAIL"
}
logMap["flag"] = flag
con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args))
for _, arg := range args {
@ -53,6 +60,10 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if err != nil {
con += " - " + err.Error()
}
logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `"))
if LogFunc != nil{
LogFunc(logMap)
}
DebugLog.Println(con)
}
@ -122,6 +133,13 @@ func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) {
return stmt, err
}
func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
a := time.Now()
stmt, err := d.db.PrepareContext(ctx, query)
debugLogQueies(d.alias, "db.Prepare", query, a, err)
return stmt, err
}
func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.db.Exec(query, args...)
@ -129,6 +147,13 @@ func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error)
return res, err
}
func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.db.ExecContext(ctx, query, args...)
debugLogQueies(d.alias, "db.Exec", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.db.Query(query, args...)
@ -136,6 +161,13 @@ func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error)
return res, err
}
func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.db.QueryContext(ctx, query, args...)
debugLogQueies(d.alias, "db.Query", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
a := time.Now()
res := d.db.QueryRow(query, args...)
@ -143,6 +175,13 @@ func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
return res
}
func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
a := time.Now()
res := d.db.QueryRowContext(ctx, query, args...)
debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...)
return res
}
func (d *dbQueryLog) Begin() (*sql.Tx, error) {
a := time.Now()
tx, err := d.db.(txer).Begin()
@ -150,6 +189,13 @@ func (d *dbQueryLog) Begin() (*sql.Tx, error) {
return tx, err
}
func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
a := time.Now()
tx, err := d.db.(txer).BeginTx(ctx, opts)
debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err)
return tx, err
}
func (d *dbQueryLog) Commit() error {
a := time.Now()
err := d.db.(txEnder).Commit()

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
)
@ -55,16 +56,19 @@ func ColValue(opt operator, value interface{}) interface{} {
// real query struct
type querySet struct {
mi *modelInfo
cond *Condition
related []string
relDepth int
limit int64
offset int64
groups []string
orders []string
distinct bool
orm *orm
mi *modelInfo
cond *Condition
related []string
relDepth int
limit int64
offset int64
groups []string
orders []string
distinct bool
forupdate bool
orm *orm
ctx context.Context
forContext bool
}
var _ QuerySeter = new(querySet)
@ -78,6 +82,15 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
return &o
}
// add raw sql to querySeter.
func (o querySet) FilterRaw(expr string, sql string) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond = o.cond.Raw(expr, sql)
return &o
}
// add NOT condition to querySeter.
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
@ -127,6 +140,12 @@ func (o querySet) Distinct() QuerySeter {
return &o
}
// add FOR UPDATE to SELECT
func (o querySet) ForUpdate() QuerySeter {
o.forupdate = true
return &o
}
// set relation model to query together.
// it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
@ -259,6 +278,13 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string)
panic(ErrNotImplement)
}
// set context to QuerySeter.
func (o querySet) WithContext(ctx context.Context) QuerySeter {
o.ctx = ctx
o.forContext = true
return &o
}
// create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet)

View File

@ -150,8 +150,10 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
case reflect.Struct:
if value == nil {
ind.Set(reflect.Zero(ind.Type()))
} else if _, ok := ind.Interface().(time.Time); ok {
return
}
switch ind.Interface().(type) {
case time.Time:
var str string
switch d := value.(type) {
case time.Time:
@ -178,7 +180,25 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
}
}
}
case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool:
indi := reflect.New(ind.Type()).Interface()
sc, ok := indi.(sql.Scanner)
if !ok {
return
}
err := sc.Scan(value)
if err == nil {
ind.Set(reflect.Indirect(reflect.ValueOf(sc)))
}
}
case reflect.Ptr:
if value == nil {
ind.Set(reflect.Zero(ind.Type()))
break
}
ind.Set(reflect.New(ind.Type().Elem()))
o.setFieldValue(reflect.Indirect(ind), value)
}
}
@ -358,7 +378,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string
if col = tags["column"]; col == "" {
col = snakeString(fe.Name)
col = nameStrategyMap[nameStrategy](fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
@ -509,7 +529,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string
if col = tags["column"]; col == "" {
col = snakeString(fe.Name)
col = nameStrategyMap[nameStrategy](fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"database/sql"
"reflect"
"time"
@ -54,7 +55,7 @@ type Ormer interface {
// for example:
// user := new(User)
// id, err = Ormer.Insert(user)
// user must a pointer and Insert will set user's pk field
// user must be a pointer and Insert will set user's pk field
Insert(interface{}) (int64, error)
// mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value")
// if colu type is integer : can use(+-*/), string : convert(colu,"value")
@ -106,6 +107,17 @@ type Ormer interface {
// ...
// err = o.Rollback()
Begin() error
// begin transaction with provided context and option
// the provided context is used until the transaction is committed or rolled back.
// if the context is canceled, the transaction will be rolled back.
// the provided TxOptions is optional and may be nil if defaults should be used.
// if a non-default isolation level is used that the driver doesn't support, an error will be returned.
// for example:
// o := NewOrm()
// err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
// ...
// err = o.Rollback()
BeginTx(ctx context.Context, opts *sql.TxOptions) error
// commit transaction
Commit() error
// rollback transaction
@ -116,6 +128,7 @@ type Ormer interface {
// // update user testing's name to slene
Raw(query string, args ...interface{}) RawSeter
Driver() Driver
DBStats() *sql.DBStats
}
// Inserter insert prepared statement
@ -135,6 +148,11 @@ type QuerySeter interface {
// // time compare
// qs.Filter("created", time.Now())
Filter(string, ...interface{}) QuerySeter
// add raw sql to querySeter.
// for example:
// qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)")
// //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18)
FilterRaw(string, string) QuerySeter
// add NOT condition to querySeter.
// have the same usage as Filter
Exclude(string, ...interface{}) QuerySeter
@ -190,6 +208,10 @@ type QuerySeter interface {
// Distinct().
// All(&permissions)
Distinct() QuerySeter
// set FOR UPDATE to query.
// for example:
// o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users)
ForUpdate() QuerySeter
// return QuerySeter execution result number
// for example:
// num, err = qs.Filter("profile__age__gt", 28).Count()
@ -374,16 +396,23 @@ type RawSeter interface {
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
//ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error)
//QueryContext(args ...interface{}) (*sql.Rows, error)
QueryRow(args ...interface{}) *sql.Row
//QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
}
// db querier
type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
// type DB interface {
@ -397,6 +426,7 @@ type dbQuerier interface {
// transaction beginner
type txer interface {
Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
// transaction ending

View File

@ -23,6 +23,18 @@ import (
"time"
)
type fn func(string) string
var (
nameStrategyMap = map[string]fn{
defaultNameStrategy: snakeString,
SnakeAcronymNameStrategy: snakeStringWithAcronym,
}
defaultNameStrategy = "snakeString"
SnakeAcronymNameStrategy = "snakeStringWithAcronym"
nameStrategy = defaultNameStrategy
)
// StrTo is the target string
type StrTo string
@ -198,7 +210,28 @@ func ToInt64(value interface{}) (d int64) {
return
}
// snake string, XxYy to xx_yy , XxYY to xx_yy
func snakeStringWithAcronym(s string) string {
data := make([]byte, 0, len(s)*2)
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
before := false
after := false
if i > 0 {
before = s[i-1] >= 'a' && s[i-1] <= 'z'
}
if i+1 < num {
after = s[i+1] >= 'a' && s[i+1] <= 'z'
}
if i > 0 && d >= 'A' && d <= 'Z' && (before || after) {
data = append(data, '_')
}
data = append(data, d)
}
return strings.ToLower(string(data[:]))
}
// snake string, XxYy to xx_yy , XxYY to xx_y_y
func snakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
@ -216,6 +249,14 @@ func snakeString(s string) string {
return strings.ToLower(string(data[:]))
}
// SetNameStrategy set different name strategy
func SetNameStrategy(s string) {
if SnakeAcronymNameStrategy != s {
nameStrategy = defaultNameStrategy
}
nameStrategy = s
}
// camel string, xx_yy to XxYy
func camelString(s string) string {
data := make([]byte, 0, len(s))

View File

@ -35,11 +35,11 @@ import (
"github.com/astaxie/beego/utils"
)
var globalRouterTemplate = `package routers
var globalRouterTemplate = `package {{.routersDir}}
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/context/param"{{.globalimport}}
)
func init() {
@ -52,6 +52,22 @@ var (
commentFilename string
pkgLastupdate map[string]int64
genInfoList map[string][]ControllerComments
routerHooks = map[string]int{
"beego.BeforeStatic": BeforeStatic,
"beego.BeforeRouter": BeforeRouter,
"beego.BeforeExec": BeforeExec,
"beego.AfterExec": AfterExec,
"beego.FinishRouter": FinishRouter,
}
routerHooksMapping = map[int]string{
BeforeStatic: "beego.BeforeStatic",
BeforeRouter: "beego.BeforeRouter",
BeforeExec: "beego.BeforeExec",
AfterExec: "beego.AfterExec",
FinishRouter: "beego.FinishRouter",
}
)
const commentPrefix = "commentsRouter_"
@ -102,6 +118,20 @@ type parsedComment struct {
routerPath string
methods []string
params map[string]parsedParam
filters []parsedFilter
imports []parsedImport
}
type parsedImport struct {
importPath string
importAlias string
}
type parsedFilter struct {
pattern string
pos int
filter string
params []bool
}
type parsedParam struct {
@ -114,24 +144,69 @@ type parsedParam struct {
func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
if f.Doc != nil {
parsedComment, err := parseComment(f.Doc.List)
parsedComments, err := parseComment(f.Doc.List)
if err != nil {
return err
}
if parsedComment.routerPath != "" {
key := pkgpath + ":" + controllerName
cc := ControllerComments{}
cc.Method = f.Name.String()
cc.Router = parsedComment.routerPath
cc.AllowHTTPMethods = parsedComment.methods
cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
genInfoList[key] = append(genInfoList[key], cc)
for _, parsedComment := range parsedComments {
if parsedComment.routerPath != "" {
key := pkgpath + ":" + controllerName
cc := ControllerComments{}
cc.Method = f.Name.String()
cc.Router = parsedComment.routerPath
cc.AllowHTTPMethods = parsedComment.methods
cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
cc.FilterComments = buildFilters(parsedComment.filters)
cc.ImportComments = buildImports(parsedComment.imports)
genInfoList[key] = append(genInfoList[key], cc)
}
}
}
return nil
}
func buildImports(pis []parsedImport) []*ControllerImportComments {
var importComments []*ControllerImportComments
for _, pi := range pis {
importComments = append(importComments, &ControllerImportComments{
ImportPath: pi.importPath,
ImportAlias: pi.importAlias,
})
}
return importComments
}
func buildFilters(pfs []parsedFilter) []*ControllerFilterComments {
var filterComments []*ControllerFilterComments
for _, pf := range pfs {
var (
returnOnOutput bool
resetParams bool
)
if len(pf.params) >= 1 {
returnOnOutput = pf.params[0]
}
if len(pf.params) >= 2 {
resetParams = pf.params[1]
}
filterComments = append(filterComments, &ControllerFilterComments{
Filter: pf.filter,
Pattern: pf.pattern,
Pos: pf.pos,
ReturnOnOutput: returnOnOutput,
ResetParams: resetParams,
})
}
return filterComments
}
func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
result := make([]*param.MethodParam, 0, len(funcParams))
for _, fparam := range funcParams {
@ -177,26 +252,15 @@ func paramInPath(name, route string) bool {
var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) {
pc = &parsedComment{}
func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) {
pcs = []*parsedComment{}
params := map[string]parsedParam{}
filters := []parsedFilter{}
imports := []parsedImport{}
for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@router") {
matches := routeRegex.FindStringSubmatch(t)
if len(matches) == 3 {
pc.routerPath = matches[1]
methods := matches[2]
if methods == "" {
pc.methods = []string{"get"}
//pc.hasGet = true
} else {
pc.methods = strings.Split(methods, ",")
//pc.hasGet = strings.Contains(methods, "get")
}
} else {
return nil, errors.New("Router information is missing")
}
} else if strings.HasPrefix(t, "@Param") {
if strings.HasPrefix(t, "@Param") {
pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param")))
if len(pv) < 4 {
logs.Error("Invalid @Param format. Needs at least 4 parameters")
@ -217,17 +281,99 @@ func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) {
p.defValue = pv[3]
p.required, _ = strconv.ParseBool(pv[4])
}
if pc.params == nil {
pc.params = map[string]parsedParam{}
params[funcParamName] = p
}
}
for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@Import") {
iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import")))
if len(iv) == 0 || len(iv) > 2 {
logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters")
continue
}
p := parsedImport{}
p.importPath = iv[0]
if len(iv) == 2 {
p.importAlias = iv[1]
}
imports = append(imports, p)
}
}
filterLoop:
for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@Filter") {
fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter")))
if len(fv) < 3 {
logs.Error("Invalid @Filter format. Needs at least 3 parameters")
continue filterLoop
}
p := parsedFilter{}
p.pattern = fv[0]
posName := fv[1]
if pos, exists := routerHooks[posName]; exists {
p.pos = pos
} else {
logs.Error("Invalid @Filter pos: ", posName)
continue filterLoop
}
p.filter = fv[2]
fvParams := fv[3:]
for _, fvParam := range fvParams {
switch fvParam {
case "true":
p.params = append(p.params, true)
case "false":
p.params = append(p.params, false)
default:
logs.Error("Invalid @Filter param: ", fvParam)
continue filterLoop
}
}
filters = append(filters, p)
}
}
for _, c := range lines {
var pc = &parsedComment{}
pc.params = params
pc.filters = filters
pc.imports = imports
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@router") {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
matches := routeRegex.FindStringSubmatch(t)
if len(matches) == 3 {
pc.routerPath = matches[1]
methods := matches[2]
if methods == "" {
pc.methods = []string{"get"}
//pc.hasGet = true
} else {
pc.methods = strings.Split(methods, ",")
//pc.hasGet = strings.Contains(methods, "get")
}
pcs = append(pcs, pc)
} else {
return nil, errors.New("Router information is missing")
}
pc.params[funcParamName] = p
}
}
return
}
// direct copy from bee\g_docs.go
// analisys params return []string
// analysis params return []string
// @Param query form string true "The email for login"
// [query form string true "The email for login"]
func getparams(str string) []string {
@ -266,8 +412,9 @@ func genRouterCode(pkgRealpath string) {
os.Mkdir(getRouterDir(pkgRealpath), 0755)
logs.Info("generate router from comments")
var (
globalinfo string
sortKey []string
globalinfo string
globalimport string
sortKey []string
)
for k := range genInfoList {
sortKey = append(sortKey, k)
@ -285,6 +432,7 @@ func genRouterCode(pkgRealpath string) {
}
allmethod = strings.TrimRight(allmethod, ",") + "}"
}
params := "nil"
if len(c.Params) > 0 {
params = "[]map[string]string{"
@ -295,6 +443,7 @@ func genRouterCode(pkgRealpath string) {
}
params = strings.TrimRight(params, ",") + "}"
}
methodParams := "param.Make("
if len(c.MethodParams) > 0 {
lines := make([]string, 0, len(c.MethodParams))
@ -306,24 +455,72 @@ func genRouterCode(pkgRealpath string) {
",\n "
}
methodParams += ")"
imports := ""
if len(c.ImportComments) > 0 {
for _, i := range c.ImportComments {
var s string
if i.ImportAlias != "" {
s = fmt.Sprintf(`
%s "%s"`, i.ImportAlias, i.ImportPath)
} else {
s = fmt.Sprintf(`
"%s"`, i.ImportPath)
}
if !strings.Contains(globalimport, s) {
imports += s
}
}
}
filters := ""
if len(c.FilterComments) > 0 {
for _, f := range c.FilterComments {
filters += fmt.Sprintf(` &beego.ControllerFilter{
Pattern: "%s",
Pos: %s,
Filter: %s,
ReturnOnOutput: %v,
ResetParams: %v,
},`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams)
}
}
if filters == "" {
filters = "nil"
} else {
filters = fmt.Sprintf(`[]*beego.ControllerFilter{
%s
}`, filters)
}
globalimport += imports
globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{
Method: "` + strings.TrimSpace(c.Method) + `",
` + "Router: `" + c.Router + "`" + `,
AllowHTTPMethods: ` + allmethod + `,
MethodParams: ` + methodParams + `,
Params: ` + params + `})
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{
Method: "` + strings.TrimSpace(c.Method) + `",
` + "Router: `" + c.Router + "`" + `,
AllowHTTPMethods: ` + allmethod + `,
MethodParams: ` + methodParams + `,
Filters: ` + filters + `,
Params: ` + params + `})
`
}
}
if globalinfo != "" {
f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename))
if err != nil {
panic(err)
}
defer f.Close()
f.WriteString(strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1))
routersDir := AppConfig.DefaultString("routersdir", "routers")
content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1)
content = strings.Replace(content, "{{.routersDir}}", routersDir, -1)
content = strings.Replace(content, "{{.globalimport}}", globalimport, -1)
f.WriteString(content)
}
}
@ -379,7 +576,8 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) {
func getRouterDir(pkgRealpath string) string {
dir := filepath.Dir(pkgRealpath)
for {
d := filepath.Join(dir, "routers")
routersDir := AppConfig.DefaultString("routersdir", "routers")
d := filepath.Join(dir, routersDir)
if utils.FileExists(d) {
return d
}

View File

@ -15,12 +15,12 @@
package beego
import (
"errors"
"fmt"
"net/http"
"path"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
@ -50,28 +50,28 @@ const (
var (
// HTTPMETHOD list the supported http methods.
HTTPMETHOD = map[string]string{
"GET": "GET",
"POST": "POST",
"PUT": "PUT",
"DELETE": "DELETE",
"PATCH": "PATCH",
"OPTIONS": "OPTIONS",
"HEAD": "HEAD",
"TRACE": "TRACE",
"CONNECT": "CONNECT",
"MKCOL": "MKCOL",
"COPY": "COPY",
"MOVE": "MOVE",
"PROPFIND": "PROPFIND",
"PROPPATCH": "PROPPATCH",
"LOCK": "LOCK",
"UNLOCK": "UNLOCK",
HTTPMETHOD = map[string]bool{
"GET": true,
"POST": true,
"PUT": true,
"DELETE": true,
"PATCH": true,
"OPTIONS": true,
"HEAD": true,
"TRACE": true,
"CONNECT": true,
"MKCOL": true,
"COPY": true,
"MOVE": true,
"PROPFIND": true,
"PROPPATCH": true,
"LOCK": true,
"UNLOCK": true,
}
// these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
"RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP",
"ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
"ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
"GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
"DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
@ -117,6 +117,7 @@ type ControllerInfo struct {
handler http.Handler
runFunction FilterFunc
routerType int
initialize func() ControllerInterface
methodParams []*param.MethodParam
}
@ -132,14 +133,15 @@ type ControllerRegister struct {
// NewControllerRegister returns a new ControllerRegister.
func NewControllerRegister() *ControllerRegister {
cr := &ControllerRegister{
return &ControllerRegister{
routers: make(map[string]*Tree),
policies: make(map[string]*Tree),
pool: sync.Pool{
New: func() interface{} {
return beecontext.NewContext()
},
},
}
cr.pool.New = func() interface{} {
return beecontext.NewContext()
}
return cr
}
// Add controller handler and pattern rules to ControllerRegister.
@ -169,7 +171,7 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
}
comma := strings.Split(colon[0], ",")
for _, m := range comma {
if _, ok := HTTPMETHOD[strings.ToUpper(m)]; m == "*" || ok {
if m == "*" || HTTPMETHOD[strings.ToUpper(m)] {
if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
methods[strings.ToUpper(m)] = colon[1]
} else {
@ -187,15 +189,39 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
route.methods = methods
route.routerType = routerTypeBeego
route.controllerType = t
route.initialize = func() ControllerInterface {
vc := reflect.New(route.controllerType)
execController, ok := vc.Interface().(ControllerInterface)
if !ok {
panic("controller is not ControllerInterface")
}
elemVal := reflect.ValueOf(c).Elem()
elemType := reflect.TypeOf(c).Elem()
execElem := reflect.ValueOf(execController).Elem()
numOfFields := elemVal.NumField()
for i := 0; i < numOfFields; i++ {
fieldType := elemType.Field(i)
elemField := execElem.FieldByName(fieldType.Name)
if elemField.CanSet() {
fieldVal := elemVal.Field(i)
elemField.Set(fieldVal)
}
}
return execController
}
route.methodParams = methodParams
if len(methods) == 0 {
for _, m := range HTTPMETHOD {
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
for k := range methods {
if k == "*" {
for _, m := range HTTPMETHOD {
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
@ -252,6 +278,10 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
key := t.PkgPath() + ":" + t.Name()
if comm, ok := GlobalControllerRouter[key]; ok {
for _, a := range comm {
for _, f := range a.Filters {
p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams)
}
p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
}
}
@ -337,7 +367,7 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) {
// })
func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
method = strings.ToUpper(method)
if _, ok := HTTPMETHOD[method]; method != "*" && !ok {
if method != "*" && !HTTPMETHOD[method] {
panic("not support http method: " + method)
}
route := &ControllerInfo{}
@ -346,7 +376,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
route.runFunction = f
methods := make(map[string]string)
if method == "*" {
for _, val := range HTTPMETHOD {
for val := range HTTPMETHOD {
methods[val] = val
}
} else {
@ -355,7 +385,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
route.methods = methods
for k := range methods {
if k == "*" {
for _, m := range HTTPMETHOD {
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
@ -375,7 +405,7 @@ func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...
pattern = path.Join(pattern, "?:all(.*)")
}
}
for _, m := range HTTPMETHOD {
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
}
@ -410,7 +440,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
route.pattern = pattern
for _, m := range HTTPMETHOD {
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
p.addToRouter(m, patternInit, route)
p.addToRouter(m, patternFix, route)
@ -449,8 +479,7 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
// add Filter into
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
if pos < BeforeStatic || pos > FinishRouter {
err = fmt.Errorf("can not find your filter position")
return
return errors.New("can not find your filter position")
}
p.enableFilter = true
p.filters[pos] = append(p.filters[pos], mr)
@ -480,10 +509,10 @@ func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) stri
}
}
}
controllName := strings.Join(paths[:len(paths)-1], "/")
controllerName := strings.Join(paths[:len(paths)-1], "/")
methodName := paths[len(paths)-1]
for m, t := range p.routers {
ok, url := p.geturl(t, "/", controllName, methodName, params, m)
ok, url := p.getURL(t, "/", controllerName, methodName, params, m)
if ok {
return url
}
@ -491,17 +520,17 @@ func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) stri
return ""
}
func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) {
func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) {
for _, subtree := range t.fixrouters {
u := path.Join(url, subtree.prefix)
ok, u := p.geturl(subtree, u, controllName, methodName, params, httpMethod)
ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod)
if ok {
return ok, u
}
}
if t.wildcard != nil {
u := path.Join(url, urlPlaceholder)
ok, u := p.geturl(t.wildcard, u, controllName, methodName, params, httpMethod)
ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod)
if ok {
return ok, u
}
@ -509,9 +538,9 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
for _, l := range t.leaves {
if c, ok := l.runObject.(*ControllerInfo); ok {
if c.routerType == routerTypeBeego &&
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) {
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) {
find := false
if _, ok := HTTPMETHOD[strings.ToUpper(methodName)]; ok {
if HTTPMETHOD[strings.ToUpper(methodName)] {
if len(c.methods) == 0 {
find = true
} else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) {
@ -548,18 +577,18 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
}
}
}
canskip := false
canSkip := false
for _, v := range l.wildcards {
if v == ":" {
canskip = true
canSkip = true
continue
}
if u, ok := params[v]; ok {
delete(params, v)
url = strings.Replace(url, urlPlaceholder, u, 1)
} else {
if canskip {
canskip = false
if canSkip {
canSkip = false
continue
}
return false, ""
@ -568,27 +597,27 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
return true, url + toURL(params)
}
var i int
var startreg bool
regurl := ""
var startReg bool
regURL := ""
for _, v := range strings.Trim(l.regexps.String(), "^$") {
if v == '(' {
startreg = true
startReg = true
continue
} else if v == ')' {
startreg = false
startReg = false
if v, ok := params[l.wildcards[i]]; ok {
delete(params, l.wildcards[i])
regurl = regurl + v
regURL = regURL + v
i++
} else {
break
}
} else if !startreg {
regurl = string(append([]rune(regurl), v))
} else if !startReg {
regURL = string(append([]rune(regURL), v))
}
}
if l.regexps.MatchString(regurl) {
ps := strings.Split(regurl, "/")
if l.regexps.MatchString(regURL) {
ps := strings.Split(regURL, "/")
for _, p := range ps {
url = strings.Replace(url, urlPlaceholder, p, 1)
}
@ -659,8 +688,8 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
// filter wrong http method
if _, ok := HTTPMETHOD[r.Method]; !ok {
http.Error(rw, "Method Not Allowed", 405)
if !HTTPMETHOD[r.Method] {
exception("405", context)
goto Admin
}
@ -749,7 +778,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
runRouter = routerInfo.controllerType
methodParams = routerInfo.methodParams
method := r.Method
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost {
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut {
method = http.MethodPut
}
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete {
@ -768,14 +797,20 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// also defined runRouter & runMethod from filter
if !isRunnable {
//Invoke the request handler
vc := reflect.New(runRouter)
execController, ok := vc.Interface().(ControllerInterface)
if !ok {
panic("controller is not ControllerInterface")
var execController ControllerInterface
if routerInfo != nil && routerInfo.initialize != nil {
execController = routerInfo.initialize()
} else {
vc := reflect.New(runRouter)
var ok bool
execController, ok = vc.Interface().(ControllerInterface)
if !ok {
panic("controller is not ControllerInterface")
}
}
//call the controller init function
execController.Init(context, runRouter.Name(), runMethod, vc.Interface())
execController.Init(context, runRouter.Name(), runMethod, execController)
//call prepare function
execController.Prepare()
@ -808,8 +843,11 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.Patch()
case http.MethodOptions:
execController.Options()
case http.MethodTrace:
execController.Trace()
default:
if !execController.HandlerFunc(runMethod) {
vc := reflect.ValueOf(execController)
method := vc.MethodByName(runMethod)
in := param.ConvertParams(methodParams, method.Type(), context)
out := method.Call(in)
@ -846,59 +884,46 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
Admin:
//admin module record QPS
statusCode := context.ResponseWriter.Status
if statusCode == 0 {
statusCode = 200
}
LogAccess(context, &startTime, statusCode)
timeDur := time.Since(startTime)
context.ResponseWriter.Elapsed = timeDur
if BConfig.Listen.EnableAdmin {
timeDur := time.Since(startTime)
pattern := ""
if routerInfo != nil {
pattern = routerInfo.pattern
}
statusCode := context.ResponseWriter.Status
if statusCode == 0 {
statusCode = 200
}
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) {
routerName := ""
if runRouter != nil {
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
} else {
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, "", timeDur)
routerName = runRouter.Name()
}
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur)
}
}
if BConfig.RunMode == DEV || BConfig.Log.AccessLogs {
timeDur := time.Since(startTime)
var devInfo string
statusCode := context.ResponseWriter.Status
if statusCode == 0 {
statusCode = 200
if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs {
match := map[bool]string{true: "match", false: "nomatch"}
devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s",
context.Input.IP(),
logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(),
timeDur.String(),
match[findRouter],
logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(),
r.URL.Path)
if routerInfo != nil {
devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern)
}
iswin := (runtime.GOOS == "windows")
statusColor := logs.ColorByStatus(iswin, statusCode)
methodColor := logs.ColorByMethod(iswin, r.Method)
resetColor := logs.ColorByMethod(iswin, "")
if findRouter {
if routerInfo != nil {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s r:%s", context.Input.IP(), statusColor, statusCode,
resetColor, timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path,
routerInfo.pattern)
} else {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path)
}
} else {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
timeDur.String(), "nomatch", methodColor, r.Method, resetColor, r.URL.Path)
}
if iswin {
logs.W32Debug(devInfo)
} else {
logs.Debug(devInfo)
}
logs.Debug(devInfo)
}
// Call WriteHeader if status code has been set changed
if context.Output.Status != 0 {
context.ResponseWriter.WriteHeader(context.Output.Status)
@ -914,7 +939,7 @@ func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, ex
context.RenderMethodResult(resultValue)
}
}
if !context.ResponseWriter.Started && context.Output.Status == 0 {
if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 {
context.Output.SetStatus(200)
}
}
@ -945,3 +970,39 @@ func toURL(params map[string]string) string {
}
return strings.TrimRight(u, "&")
}
// LogAccess logging info HTTP Access
func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) {
//Skip logging if AccessLogs config is false
if !BConfig.Log.AccessLogs {
return
}
//Skip logging static requests unless EnableStaticLogs config is true
if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) {
return
}
var (
requestTime time.Time
elapsedTime time.Duration
r = ctx.Request
)
if startTime != nil {
requestTime = *startTime
elapsedTime = time.Since(*startTime)
}
record := &logs.AccessLogRecord{
RemoteAddr: ctx.Input.IP(),
RequestTime: requestTime,
RequestMethod: r.Method,
Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto),
ServerProtocol: r.Proto,
Host: r.Host,
Status: statusCode,
ElapsedTime: elapsedTime,
HTTPReferrer: r.Header.Get("Referer"),
HTTPUserAgent: r.Header.Get("User-Agent"),
RemoteUser: r.Header.Get("Remote-User"),
BodyBytesSent: 0, //@todo this one is missing!
}
logs.AccessLog(record, BConfig.Log.AccessLogsFormat)
}

View File

@ -14,9 +14,9 @@
// Package redis for session provider
//
// depend on github.com/garyburd/redigo/redis
// depend on github.com/gomodule/redigo/redis
//
// go install github.com/garyburd/redigo/redis
// go install github.com/gomodule/redigo/redis
//
// Usage:
// import(
@ -24,10 +24,10 @@
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``)
// go globalSessions.GC()
// }
// func init() {
// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package redis
@ -37,10 +37,11 @@ import (
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/session"
"github.com/garyburd/redigo/redis"
"github.com/gomodule/redigo/redis"
)
var redispder = &Provider{}
@ -118,8 +119,8 @@ type Provider struct {
}
// SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379,100,astaxie,0
// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
// e.g. 127.0.0.1:6379,100,astaxie,0,30
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
@ -149,24 +150,39 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
} else {
rp.dbNum = 0
}
rp.poollist = redis.NewPool(func() (redis.Conn, error) {
c, err := redis.Dial("tcp", rp.savePath)
if err != nil {
return nil, err
var idleTimeout time.Duration = 0
if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4])
if err == nil && timeout > 0 {
idleTimeout = time.Duration(timeout) * time.Second
}
if rp.password != "" {
if _, err = c.Do("AUTH", rp.password); err != nil {
c.Close()
}
rp.poollist = &redis.Pool{
Dial: func() (redis.Conn, error) {
c, err := redis.Dial("tcp", rp.savePath)
if err != nil {
return nil, err
}
}
_, err = c.Do("SELECT", rp.dbNum)
if err != nil {
c.Close()
return nil, err
}
return c, err
}, rp.poolsize)
if rp.password != "" {
if _, err = c.Do("AUTH", rp.password); err != nil {
c.Close()
return nil, err
}
}
// some redis proxy such as twemproxy is not support select command
if rp.dbNum > 0 {
_, err = c.Do("SELECT", rp.dbNum)
if err != nil {
c.Close()
return nil, err
}
}
return c, err
},
MaxIdle: rp.poolsize,
}
rp.poollist.IdleTimeout = idleTimeout
return rp.poollist.Get().Err()
}

View File

@ -19,8 +19,10 @@ import (
"io/ioutil"
"net/http"
"os"
"errors"
"path"
"path/filepath"
"strings"
"sync"
"time"
)
@ -78,6 +80,8 @@ func (fs *FileSessionStore) SessionID() string {
// SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
b, err := EncodeGob(fs.values)
if err != nil {
SLogger.Println(err)
@ -125,6 +129,12 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
// if file is not exist, create it.
// the file path is generated from sid string.
func (fp *FileProvider) SessionRead(sid string) (Store, error) {
if strings.ContainsAny(sid, "./") {
return nil, nil
}
if len(sid) < 2 {
return nil, errors.New("length of the sid is less than 2")
}
filepder.lock.Lock()
defer filepder.lock.Unlock()
@ -164,7 +174,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
}
// SessionExist Check file session exist.
// it checkes the file named from sid exist or not.
// it checks the file named from sid exist or not.
func (fp *FileProvider) SessionExist(sid string) bool {
filepder.lock.Lock()
defer filepder.lock.Unlock()

View File

@ -149,7 +149,7 @@ func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime
// 2. Verify MAC. Value is "date|value|mac".
parts := bytes.SplitN(b, []byte("|"), 3)
if len(parts) != 3 {
return nil, errors.New("Decode: invalid value %v")
return nil, errors.New("Decode: invalid value format")
}
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)

View File

@ -81,6 +81,15 @@ func Register(name string, provide Provider) {
provides[name] = provide
}
//GetProvider
func GetProvider(name string) (Provider, error) {
provider, ok := provides[name]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name)
}
return provider, nil
}
// ManagerConfig define the session config
type ManagerConfig struct {
CookieName string `json:"cookieName"`
@ -96,6 +105,7 @@ type ManagerConfig struct {
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
SessionIDPrefix string `json:"sessionIDPrefix"`
}
// Manager contains Provider and its configuration.
@ -153,6 +163,11 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) {
}, nil
}
// GetProvider return current manager's provider
func (manager *Manager) GetProvider() Provider {
return manager.provider
}
// getSid retrieves session identifier from HTTP Request.
// First try to retrieve id by reading from cookie, session cookie name is configurable,
// if not exist, then retrieve id from querying parameters.
@ -331,7 +346,7 @@ func (manager *Manager) sessionID() (string, error) {
if n != len(b) || err != nil {
return "", fmt.Errorf("Could not successfully read from the system CSPRNG")
}
return hex.EncodeToString(b), nil
return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil
}
// Set cookie with https.

View File

@ -74,7 +74,7 @@ func serverStaticRouter(ctx *context.Context) {
if enableCompress {
acceptEncoding = context.ParseEncoding(ctx.Request)
}
b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding)
b, n, sch, reader, err := openFile(filePath, fileInfo, acceptEncoding)
if err != nil {
if BConfig.RunMode == DEV {
logs.Warn("Can't compress the file:", filePath, err)
@ -89,47 +89,53 @@ func serverStaticRouter(ctx *context.Context) {
ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10))
}
http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, sch)
http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, reader)
}
type serveContentHolder struct {
*bytes.Reader
data []byte
modTime time.Time
size int64
encoding string
}
type serveContentReader struct {
*bytes.Reader
}
var (
staticFileMap = make(map[string]*serveContentHolder)
mapLock sync.RWMutex
)
func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, error) {
func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) {
mapKey := acceptEncoding + ":" + filePath
mapLock.RLock()
mapFile := staticFileMap[mapKey]
mapLock.RUnlock()
if isOk(mapFile, fi) {
return mapFile.encoding != "", mapFile.encoding, mapFile, nil
reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)}
return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil
}
mapLock.Lock()
defer mapLock.Unlock()
if mapFile = staticFileMap[mapKey]; !isOk(mapFile, fi) {
file, err := os.Open(filePath)
if err != nil {
return false, "", nil, err
return false, "", nil, nil, err
}
defer file.Close()
var bufferWriter bytes.Buffer
_, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file)
if err != nil {
return false, "", nil, err
return false, "", nil, nil, err
}
mapFile = &serveContentHolder{Reader: bytes.NewReader(bufferWriter.Bytes()), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n}
mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n}
staticFileMap[mapKey] = mapFile
}
return mapFile.encoding != "", mapFile.encoding, mapFile, nil
reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)}
return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil
}
func isOk(s *serveContentHolder, fi os.FileInfo) bool {
@ -172,7 +178,7 @@ func searchFile(ctx *context.Context) (string, os.FileInfo, error) {
if !strings.Contains(requestPath, prefix) {
continue
}
if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' {
if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' {
continue
}
filePath := path.Join(staticDir, requestPath[len(prefix):])

View File

@ -20,6 +20,7 @@ import (
"html/template"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"regexp"
@ -37,9 +38,10 @@ var (
beeViewPathTemplates = make(map[string]map[string]*template.Template)
templatesLock sync.RWMutex
// beeTemplateExt stores the template extension which will build
beeTemplateExt = []string{"tpl", "html"}
beeTemplateExt = []string{"tpl", "html", "gohtml"}
// beeTemplatePreprocessors stores associations of extension -> preprocessor handler
beeTemplateEngines = map[string]templatePreProcessor{}
beeTemplateFS = defaultFSFunc
)
// ExecuteTemplate applies the template with name to the specified data object,
@ -181,12 +183,17 @@ func lockViewPaths() {
// BuildTemplate will build all template files in a directory.
// it makes beego can render any template file in view directory.
func BuildTemplate(dir string, files ...string) error {
if _, err := os.Stat(dir); err != nil {
var err error
fs := beeTemplateFS()
f, err := fs.Open(dir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return errors.New("dir open err")
}
defer f.Close()
beeTemplates, ok := beeViewPathTemplates[dir]
if !ok {
panic("Unknown view path: " + dir)
@ -195,11 +202,11 @@ func BuildTemplate(dir string, files ...string) error {
root: dir,
files: make(map[string][]string),
}
err := filepath.Walk(dir, func(path string, f os.FileInfo, err error) error {
err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error {
return self.visit(path, f, err)
})
if err != nil {
fmt.Printf("filepath.Walk() returned %v\n", err)
fmt.Printf("Walk() returned %v\n", err)
return err
}
buildAllFiles := len(files) == 0
@ -210,17 +217,18 @@ func BuildTemplate(dir string, files ...string) error {
ext := filepath.Ext(file)
var t *template.Template
if len(ext) == 0 {
t, err = getTemplate(self.root, file, v...)
t, err = getTemplate(self.root, fs, file, v...)
} else if fn, ok := beeTemplateEngines[ext[1:]]; ok {
t, err = fn(self.root, file, beegoTplFuncMap)
} else {
t, err = getTemplate(self.root, file, v...)
t, err = getTemplate(self.root, fs, file, v...)
}
if err != nil {
logs.Error("parse template err:", file, err)
} else {
beeTemplates[file] = t
templatesLock.Unlock()
return err
}
beeTemplates[file] = t
templatesLock.Unlock()
}
}
@ -228,20 +236,23 @@ func BuildTemplate(dir string, files ...string) error {
return nil
}
func getTplDeep(root, file, parent string, t *template.Template) (*template.Template, [][]string, error) {
func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *template.Template) (*template.Template, [][]string, error) {
var fileAbsPath string
var rParent string
if filepath.HasPrefix(file, "../") {
var err error
if strings.HasPrefix(file, "../") {
rParent = filepath.Join(filepath.Dir(parent), file)
fileAbsPath = filepath.Join(root, filepath.Dir(parent), file)
} else {
rParent = file
fileAbsPath = filepath.Join(root, file)
}
if e := utils.FileExists(fileAbsPath); !e {
f, err := fs.Open(fileAbsPath)
if err != nil {
panic("can't find template file:" + file)
}
data, err := ioutil.ReadFile(fileAbsPath)
defer f.Close()
data, err := ioutil.ReadAll(f)
if err != nil {
return nil, [][]string{}, err
}
@ -260,7 +271,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
if !HasTemplateExt(m[1]) {
continue
}
_, _, err = getTplDeep(root, m[1], rParent, t)
_, _, err = getTplDeep(root, fs, m[1], rParent, t)
if err != nil {
return nil, [][]string{}, err
}
@ -269,14 +280,14 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
return t, allSub, nil
}
func getTemplate(root, file string, others ...string) (t *template.Template, err error) {
func getTemplate(root string, fs http.FileSystem, file string, others ...string) (t *template.Template, err error) {
t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap)
var subMods [][]string
t, subMods, err = getTplDeep(root, file, "", t)
t, subMods, err = getTplDeep(root, fs, file, "", t)
if err != nil {
return nil, err
}
t, err = _getTemplate(t, root, subMods, others...)
t, err = _getTemplate(t, root, fs, subMods, others...)
if err != nil {
return nil, err
@ -284,7 +295,7 @@ func getTemplate(root, file string, others ...string) (t *template.Template, err
return
}
func _getTemplate(t0 *template.Template, root string, subMods [][]string, others ...string) (t *template.Template, err error) {
func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMods [][]string, others ...string) (t *template.Template, err error) {
t = t0
for _, m := range subMods {
if len(m) == 2 {
@ -296,11 +307,11 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
for _, otherFile := range others {
if otherFile == m[1] {
var subMods1 [][]string
t, subMods1, err = getTplDeep(root, otherFile, "", t)
t, subMods1, err = getTplDeep(root, fs, otherFile, "", t)
if err != nil {
logs.Trace("template parse file err:", err)
} else if len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...)
t, err = _getTemplate(t, root, fs, subMods1, others...)
}
break
}
@ -309,8 +320,16 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
for _, otherFile := range others {
var data []byte
fileAbsPath := filepath.Join(root, otherFile)
data, err = ioutil.ReadFile(fileAbsPath)
f, err := fs.Open(fileAbsPath)
if err != nil {
f.Close()
logs.Trace("template file parse error, not success open file:", err)
continue
}
data, err = ioutil.ReadAll(f)
f.Close()
if err != nil {
logs.Trace("template file parse error, not success read file:", err)
continue
}
reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"")
@ -318,11 +337,14 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
for _, sub := range allSub {
if len(sub) == 2 && sub[1] == m[1] {
var subMods1 [][]string
t, subMods1, err = getTplDeep(root, otherFile, "", t)
t, subMods1, err = getTplDeep(root, fs, otherFile, "", t)
if err != nil {
logs.Trace("template parse file err:", err)
} else if len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...)
t, err = _getTemplate(t, root, fs, subMods1, others...)
if err != nil {
logs.Trace("template parse file err:", err)
}
}
break
}
@ -334,6 +356,17 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
return
}
type templateFSFunc func() http.FileSystem
func defaultFSFunc() http.FileSystem {
return FileSystem{}
}
// SetTemplateFSFunc set default filesystem function
func SetTemplateFSFunc(fnt templateFSFunc) {
beeTemplateFS = fnt
}
// SetViewsPath sets view directory path in beego application.
func SetViewsPath(path string) *App {
BConfig.WebConfig.ViewsPath = path

View File

@ -17,6 +17,7 @@ package beego
import (
"errors"
"fmt"
"html"
"html/template"
"net/url"
"reflect"
@ -54,21 +55,21 @@ func Substr(s string, start, length int) string {
// HTML2str returns escaping text convert from html.
func HTML2str(html string) string {
re, _ := regexp.Compile(`\<[\S\s]+?\>`)
re := regexp.MustCompile(`\<[\S\s]+?\>`)
html = re.ReplaceAllStringFunc(html, strings.ToLower)
//remove STYLE
re, _ = regexp.Compile(`\<style[\S\s]+?\</style\>`)
re = regexp.MustCompile(`\<style[\S\s]+?\</style\>`)
html = re.ReplaceAllString(html, "")
//remove SCRIPT
re, _ = regexp.Compile(`\<script[\S\s]+?\</script\>`)
re = regexp.MustCompile(`\<script[\S\s]+?\</script\>`)
html = re.ReplaceAllString(html, "")
re, _ = regexp.Compile(`\<[\S\s]+?\>`)
re = regexp.MustCompile(`\<[\S\s]+?\>`)
html = re.ReplaceAllString(html, "\n")
re, _ = regexp.Compile(`\s{2,}`)
re = regexp.MustCompile(`\s{2,}`)
html = re.ReplaceAllString(html, "\n")
return strings.TrimSpace(html)
@ -171,7 +172,7 @@ func GetConfig(returnType, key string, defaultVal interface{}) (value interface{
case "DIY":
value, err = AppConfig.DIY(key)
default:
err = errors.New("Config keys must be of type String, Bool, Int, Int64, Float, or DIY")
err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY")
}
if err != nil {
@ -207,14 +208,12 @@ func Htmlquote(text string) string {
'&lt;&#39;&amp;&quot;&gt;'
*/
text = strings.Replace(text, "&", "&amp;", -1) // Must be done first!
text = strings.Replace(text, "<", "&lt;", -1)
text = strings.Replace(text, ">", "&gt;", -1)
text = strings.Replace(text, "'", "&#39;", -1)
text = strings.Replace(text, "\"", "&quot;", -1)
text = strings.Replace(text, "“", "&ldquo;", -1)
text = strings.Replace(text, "”", "&rdquo;", -1)
text = strings.Replace(text, " ", "&nbsp;", -1)
text = html.EscapeString(text)
text = strings.NewReplacer(
``, "&ldquo;",
``, "&rdquo;",
` `, "&nbsp;",
).Replace(text)
return strings.TrimSpace(text)
}
@ -228,17 +227,7 @@ func Htmlunquote(text string) string {
'<\\'&">'
*/
// strings.Replace(s, old, new, n)
// 在s字符串中把old字符串替换为new字符串n表示替换的次数小于0表示全部替换
text = strings.Replace(text, "&nbsp;", " ", -1)
text = strings.Replace(text, "&rdquo;", "”", -1)
text = strings.Replace(text, "&ldquo;", "“", -1)
text = strings.Replace(text, "&quot;", "\"", -1)
text = strings.Replace(text, "&#39;", "'", -1)
text = strings.Replace(text, "&gt;", ">", -1)
text = strings.Replace(text, "&lt;", "<", -1)
text = strings.Replace(text, "&amp;", "&", -1) // Must be done last!
text = html.UnescapeString(text)
return strings.TrimSpace(text)
}
@ -308,9 +297,21 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e
tag = tags[0]
}
value := form.Get(tag)
if len(value) == 0 {
continue
formValues := form[tag]
var value string
if len(formValues) == 0 {
defaultValue := fieldT.Tag.Get("default")
if defaultValue != "" {
value = defaultValue
} else {
continue
}
}
if len(formValues) == 1 {
value = formValues[0]
if value == "" {
continue
}
}
switch fieldT.Type.Kind() {
@ -360,6 +361,8 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e
if len(value) >= 25 {
value = value[:25]
t, err = time.ParseInLocation(time.RFC3339, value, time.Local)
} else if strings.HasSuffix(strings.ToUpper(value), "Z") {
t, err = time.ParseInLocation(time.RFC3339, value, time.Local)
} else if len(value) >= 19 {
if strings.Contains(value, "T") {
value = value[:19]
@ -703,7 +706,7 @@ func ge(arg1, arg2 interface{}) (bool, error) {
// MapGet getting value from map by keys
// usage:
// Data["m"] = map[string]interface{} {
// Data["m"] = M{
// "a": 1,
// "1": map[string]float64{
// "c": 4,

View File

@ -20,6 +20,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"
)
@ -32,6 +33,7 @@ type bounds struct {
// The bounds for each field.
var (
AdminTaskList map[string]Tasker
taskLock sync.Mutex
stop chan bool
changed chan bool
isstart bool
@ -389,6 +391,8 @@ func dayMatches(s *Schedule, t time.Time) bool {
// StartTask start all tasks
func StartTask() {
taskLock.Lock()
defer taskLock.Unlock()
if isstart {
//If already started no need to start another goroutine.
return
@ -428,6 +432,9 @@ func run() {
continue
case <-changed:
now = time.Now().Local()
for _, t := range AdminTaskList {
t.SetNext(now)
}
continue
case <-stop:
return
@ -437,6 +444,8 @@ func run() {
// StopTask stop all tasks
func StopTask() {
taskLock.Lock()
defer taskLock.Unlock()
if isstart {
isstart = false
stop <- true
@ -446,6 +455,9 @@ func StopTask() {
// AddTask add task with name
func AddTask(taskname string, t Tasker) {
taskLock.Lock()
defer taskLock.Unlock()
t.SetNext(time.Now().Local())
AdminTaskList[taskname] = t
if isstart {
changed <- true
@ -454,6 +466,8 @@ func AddTask(taskname string, t Tasker) {
// DeleteTask delete task with name
func DeleteTask(taskname string) {
taskLock.Lock()
defer taskLock.Unlock()
delete(AdminTaskList, taskname)
if isstart {
changed <- true

View File

@ -28,7 +28,7 @@ var (
)
// Tree has three elements: FixRouter/wildcard/leaves
// fixRouter sotres Fixed Router
// fixRouter stores Fixed Router
// wildcard stores params
// leaves store the endpoint information
type Tree struct {

View File

@ -162,7 +162,7 @@ func (e *Email) Bytes() ([]byte, error) {
// AttachFile Add attach file to the send mail
func (e *Email) AttachFile(args ...string) (a *Attachment, err error) {
if len(args) < 1 && len(args) > 2 {
if len(args) < 1 || len(args) > 2 { // change && to ||
err = errors.New("Must specify a file name and number of parameters can not exceed at least two")
return
}
@ -183,7 +183,7 @@ func (e *Email) AttachFile(args ...string) (a *Attachment, err error) {
// Attach is used to attach content from an io.Reader to the email.
// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type.
func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) {
if len(args) < 1 && len(args) > 2 {
if len(args) < 1 || len(args) > 2 { // change && to ||
err = errors.New("Must specify the file type and number of parameters can not exceed at least two")
return
}

View File

@ -3,19 +3,78 @@ package utils
import (
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
)
// GetGOPATHs returns all paths in GOPATH variable.
func GetGOPATHs() []string {
gopath := os.Getenv("GOPATH")
if gopath == "" && strings.Compare(runtime.Version(), "go1.8") >= 0 {
if gopath == "" && compareGoVersion(runtime.Version(), "go1.8") >= 0 {
gopath = defaultGOPATH()
}
return filepath.SplitList(gopath)
}
func compareGoVersion(a, b string) int {
reg := regexp.MustCompile("^\\d*")
a = strings.TrimPrefix(a, "go")
b = strings.TrimPrefix(b, "go")
versionsA := strings.Split(a, ".")
versionsB := strings.Split(b, ".")
for i := 0; i < len(versionsA) && i < len(versionsB); i++ {
versionA := versionsA[i]
versionB := versionsB[i]
vA, err := strconv.Atoi(versionA)
if err != nil {
str := reg.FindString(versionA)
if str != "" {
vA, _ = strconv.Atoi(str)
} else {
vA = -1
}
}
vB, err := strconv.Atoi(versionB)
if err != nil {
str := reg.FindString(versionB)
if str != "" {
vB, _ = strconv.Atoi(str)
} else {
vB = -1
}
}
if vA > vB {
// vA = 12, vB = 8
return 1
} else if vA < vB {
// vA = 6, vB = 8
return -1
} else if vA == -1 {
// vA = rc1, vB = rc3
return strings.Compare(versionA, versionB)
}
// vA = vB = 8
continue
}
if len(versionsA) > len(versionsB) {
return 1
} else if len(versionsA) == len(versionsB) {
return 0
}
return -1
}
func defaultGOPATH() string {
env := "HOME"
if runtime.GOOS == "windows" {

View File

@ -112,7 +112,7 @@ type Validation struct {
RequiredFirst bool
Errors []*Error
ErrorsMap map[string]*Error
ErrorsMap map[string][]*Error
}
// Clear Clean all ValidationError.
@ -129,7 +129,7 @@ func (v *Validation) HasErrors() bool {
// ErrorMap Return the errors mapped by key.
// If there are multiple validation errors associated with a single key, the
// first one "wins". (Typically the first validation will be the more basic).
func (v *Validation) ErrorMap() map[string]*Error {
func (v *Validation) ErrorMap() map[string][]*Error {
return v.ErrorsMap
}
@ -245,7 +245,21 @@ func (v *Validation) ZipCode(obj interface{}, key string) *Result {
}
func (v *Validation) apply(chk Validator, obj interface{}) *Result {
if chk.IsSatisfied(obj) {
if nil == obj {
if chk.IsSatisfied(obj) {
return &Result{Ok: true}
}
} else if reflect.TypeOf(obj).Kind() == reflect.Ptr {
if reflect.ValueOf(obj).IsNil() {
if chk.IsSatisfied(nil) {
return &Result{Ok: true}
}
} else {
if chk.IsSatisfied(reflect.ValueOf(obj).Elem().Interface()) {
return &Result{Ok: true}
}
}
} else if chk.IsSatisfied(obj) {
return &Result{Ok: true}
}
@ -278,14 +292,35 @@ func (v *Validation) apply(chk Validator, obj interface{}) *Result {
}
}
// AddError adds independent error message for the provided key
func (v *Validation) AddError(key, message string) {
Name := key
Field := ""
parts := strings.Split(key, ".")
if len(parts) == 2 {
Field = parts[0]
Name = parts[1]
}
err := &Error{
Message: message,
Key: key,
Name: Name,
Field: Field,
}
v.setError(err)
}
func (v *Validation) setError(err *Error) {
v.Errors = append(v.Errors, err)
if v.ErrorsMap == nil {
v.ErrorsMap = make(map[string]*Error)
v.ErrorsMap = make(map[string][]*Error)
}
if _, ok := v.ErrorsMap[err.Field]; !ok {
v.ErrorsMap[err.Field] = err
v.ErrorsMap[err.Field] = []*Error{}
}
v.ErrorsMap[err.Field] = append(v.ErrorsMap[err.Field], err)
}
// SetError Set error message for one field in ValidationError
@ -330,13 +365,24 @@ func (v *Validation) Valid(obj interface{}) (b bool, err error) {
return
}
var hasReuired bool
var hasRequired bool
for _, vf := range vfs {
if vf.Name == "Required" {
hasReuired = true
hasRequired = true
}
if !hasReuired && v.RequiredFirst && len(objV.Field(i).String()) == 0 {
currentField := objV.Field(i).Interface()
if objV.Field(i).Kind() == reflect.Ptr {
if objV.Field(i).IsNil() {
currentField = ""
} else {
currentField = objV.Field(i).Elem().Interface()
}
}
chk := Required{""}.IsSatisfied(currentField)
if !hasRequired && v.RequiredFirst && !chk {
if _, ok := CanSkipFuncs[vf.Name]; ok {
continue
}
@ -393,3 +439,9 @@ func (v *Validation) RecursiveValid(objc interface{}) (bool, error) {
}
return pass, err
}
func (v *Validation) CanSkipAlso(skipFunc string) {
if _, ok := CanSkipFuncs[skipFunc]; !ok {
CanSkipFuncs[skipFunc] = struct{}{}
}
}

View File

@ -632,7 +632,7 @@ func (b Base64) GetLimitValue() interface{} {
}
// just for chinese mobile phone number
var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\d{8}$`)
var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][01356789]|[4][579]))\d{8}$`)
// Mobile check struct
type Mobile struct {

Some files were not shown because too many files have changed in this diff Show More