diff --git a/.gitignore b/.gitignore index c9a0a10..fc6f927 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ release/ .vscode/ +.htpasswd diff --git a/Makefile b/Makefile index d3f7651..08f90ab 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,7 @@ .SILENT : +export GO111MODULE=on + # Author AUTHOR=github.com/ncarlier diff --git a/README.md b/README.md index a00cd1f..607dd0f 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ You can configure the daemon by: | Variable | Default | Description | |----------|---------|-------------| | `APP_LISTEN_ADDR` | `:8080` | HTTP service address | +| `APP_PASSWD_FILE` | `.htpasswd` | Password file for HTTP basic authentication | | `APP_NB_WORKERS` | `2` | The number of workers to start | | `APP_HOOK_TIMEOUT` | `10` | Hook maximum delay before timeout (in second) | | `APP_SCRIPTS_DIR` | `./scripts` | Scripts directory | @@ -64,6 +65,7 @@ You can configure the daemon by: | Parameter | Default | Description | |----------|---------|-------------| | `-l
or --listen
` | `:8080` | HTTP service address | +| `-p or --passwd ` | `.htpasswd` | Password file for HTTP basic authentication | `-d or --debug` | false | Output debug logs | | `--nb-workers ` | `2` | The number of workers to start | | `--scripts ` | `./scripts` | Scripts directory | @@ -199,6 +201,34 @@ SMTP notification configuration: The log file will be sent as an GZIP attachment. +### Authentication + +You can restrict access to webhooks using HTTP basic authentication. + +To activate basic authentication, you have to create a `htpasswd` file: + +```bash +$ # create passwd file the user 'api' +$ htpasswd -B -c .htpasswd api +``` +This command will ask for a password and store it in the htpawsswd file. + +Please note that by default, the daemon will try to load the `.htpasswd` file. + +But you can override this behavior by specifying the location of the file: + +```bash +$ APP_PASSWD_FILE=/etc/webhookd/users.htpasswd +$ # or +$ webhookd -p /etc/webhookd/users.htpasswd +``` + +Once configured, you must call webhooks using basic authentication: + +```bash +$ curl -u api:test -XPOST "http://localhost:8080/echo?msg=hello" +``` + --- diff --git a/config.go b/config.go deleted file mode 100644 index 42d2469..0000000 --- a/config.go +++ /dev/null @@ -1,70 +0,0 @@ -package main - -import ( - "bytes" - "flag" - "os" - "strconv" - - "github.com/ncarlier/webhookd/pkg/auth" -) - -// Config contain global configuration -type Config struct { - ListenAddr *string - NbWorkers *int - Debug *bool - Timeout *int - ScriptDir *string - Authentication *string - AuthenticationParam *string -} - -var config = &Config{ - ListenAddr: flag.String("listen", getEnv("LISTEN_ADDR", ":8080"), "HTTP service address (e.g.address, ':8080')"), - NbWorkers: flag.Int("nb-workers", getIntEnv("NB_WORKERS", 2), "The number of workers to start"), - Debug: flag.Bool("debug", getBoolEnv("DEBUG", false), "Output debug logs"), - Timeout: flag.Int("timeout", getIntEnv("HOOK_TIMEOUT", 10), "Hook maximum delay before timeout (in second)"), - ScriptDir: flag.String("scripts", getEnv("SCRIPTS_DIR", "scripts"), "Scripts directory"), - Authentication: flag.String("auth", getEnv("AUTH", "none"), ""), - AuthenticationParam: flag.String("auth-param", getEnv("AUTH_PARAM", ""), func() string { - authdocwriter := bytes.NewBufferString("Authentication method. Available methods: ") - - for key, method := range auth.AvailableMethods { - authdocwriter.WriteRune('\n') - authdocwriter.WriteString(key) - authdocwriter.WriteRune(':') - authdocwriter.WriteString(method.Usage()) - } - return authdocwriter.String() - }()), -} - -func init() { - flag.StringVar(config.ListenAddr, "l", *config.ListenAddr, "HTTP service (e.g address: ':8080')") - flag.BoolVar(config.Debug, "d", *config.Debug, "Output debug logs") - -} - -func getEnv(key, fallback string) string { - if value, ok := os.LookupEnv("APP_" + key); ok { - return value - } - return fallback -} - -func getIntEnv(key string, fallback int) int { - strValue := getEnv(key, strconv.Itoa(fallback)) - if value, err := strconv.Atoi(strValue); err == nil { - return value - } - return fallback -} - -func getBoolEnv(key string, fallback bool) bool { - strValue := getEnv(key, strconv.FormatBool(fallback)) - if value, err := strconv.ParseBool(strValue); err == nil { - return value - } - return fallback -} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..09f6847 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/ncarlier/webhookd + +require golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8c4e7ae --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0= +golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= diff --git a/main.go b/main.go index 4bd20b6..8d44559 100644 --- a/main.go +++ b/main.go @@ -3,30 +3,17 @@ package main import ( "context" "flag" - "fmt" - "log" "net/http" "os" "os/signal" - "sync/atomic" "time" "github.com/ncarlier/webhookd/pkg/api" - "github.com/ncarlier/webhookd/pkg/auth" + "github.com/ncarlier/webhookd/pkg/config" "github.com/ncarlier/webhookd/pkg/logger" "github.com/ncarlier/webhookd/pkg/worker" ) -type key int - -const ( - requestIDKey key = 0 -) - -var ( - healthy int32 -) - func main() { flag.Parse() @@ -35,47 +22,25 @@ func main() { return } - var authmethod auth.Method - name := *config.Authentication - if _, ok := auth.AvailableMethods[name]; ok { - authmethod = auth.AvailableMethods[name] - if err := authmethod.ParseParam(*config.AuthenticationParam); err != nil { - fmt.Println("Authentication parameter is not valid:", err.Error()) - fmt.Println(authmethod.Usage()) - os.Exit(2) - } - } else { - fmt.Println("Authentication name is not valid:", name) - os.Exit(2) - } + conf := config.Get() level := "info" - if *config.Debug { + if *conf.Debug { level = "debug" } logger.Init(level) logger.Debug.Println("Starting webhookd server...") - logger.Debug.Println("Using Authentication:", name) - authmethod.Init(*config.Debug) - - router := http.NewServeMux() - router.Handle("/", api.Index(*config.Timeout, *config.ScriptDir)) - router.Handle("/healthz", healthz()) - - nextRequestID := func() string { - return fmt.Sprintf("%d", time.Now().UnixNano()) - } server := &http.Server{ - Addr: *config.ListenAddr, - Handler: authmethod.Middleware()(tracing(nextRequestID)(logging(logger.Debug)(router))), + Addr: *conf.ListenAddr, + Handler: api.NewRouter(config.Get()), ErrorLog: logger.Error, } // Start the dispatcher. - logger.Debug.Printf("Starting the dispatcher (%d workers)...\n", *config.NbWorkers) - worker.StartDispatcher(*config.NbWorkers) + logger.Debug.Printf("Starting the dispatcher (%d workers)...\n", *conf.NbWorkers) + worker.StartDispatcher(*conf.NbWorkers) done := make(chan bool) quit := make(chan os.Signal, 1) @@ -84,7 +49,7 @@ func main() { go func() { <-quit logger.Debug.Println("Server is shutting down...") - atomic.StoreInt32(&healthy, 0) + api.Shutdown() ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -96,51 +61,12 @@ func main() { close(done) }() - logger.Info.Println("Server is ready to handle requests at", *config.ListenAddr) - atomic.StoreInt32(&healthy, 1) + logger.Info.Println("Server is ready to handle requests at", *conf.ListenAddr) + api.Start() if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Error.Fatalf("Could not listen on %s: %v\n", *config.ListenAddr, err) + logger.Error.Fatalf("Could not listen on %s: %v\n", *conf.ListenAddr, err) } <-done logger.Debug.Println("Server stopped") } - -func healthz() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if atomic.LoadInt32(&healthy) == 1 { - w.WriteHeader(http.StatusNoContent) - return - } - w.WriteHeader(http.StatusServiceUnavailable) - }) -} - -func logging(logger *log.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer func() { - requestID, ok := r.Context().Value(requestIDKey).(string) - if !ok { - requestID = "unknown" - } - logger.Println(requestID, r.Method, r.URL.Path, r.RemoteAddr, r.UserAgent()) - }() - next.ServeHTTP(w, r) - }) - } -} - -func tracing(nextRequestID func() string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestID := r.Header.Get("X-Request-Id") - if requestID == "" { - requestID = nextRequestID() - } - ctx := context.WithValue(r.Context(), requestIDKey, requestID) - w.Header().Set("X-Request-Id", requestID) - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} diff --git a/pkg/api/healthz.go b/pkg/api/healthz.go new file mode 100644 index 0000000..7c1af16 --- /dev/null +++ b/pkg/api/healthz.go @@ -0,0 +1,32 @@ +package api + +import ( + "net/http" + "sync/atomic" + + "github.com/ncarlier/webhookd/pkg/config" +) + +var ( + healthy int32 +) + +// Shutdown set API as stopped +func Shutdown() { + atomic.StoreInt32(&healthy, 0) +} + +// Start set API as started +func Start() { + atomic.StoreInt32(&healthy, 1) +} + +func healthz(conf *config.Config) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if atomic.LoadInt32(&healthy) == 1 { + w.WriteHeader(http.StatusNoContent) + return + } + w.WriteHeader(http.StatusServiceUnavailable) + }) +} diff --git a/pkg/api/api.go b/pkg/api/index.go similarity index 91% rename from pkg/api/api.go rename to pkg/api/index.go index 398db89..5de4f06 100644 --- a/pkg/api/api.go +++ b/pkg/api/index.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" + "github.com/ncarlier/webhookd/pkg/config" "github.com/ncarlier/webhookd/pkg/logger" "github.com/ncarlier/webhookd/pkg/tools" "github.com/ncarlier/webhookd/pkg/worker" @@ -24,10 +25,10 @@ func atoiFallback(str string, fallback int) int { return fallback } -// Index is the main handler of the API. -func Index(timeout int, scrDir string) http.Handler { - defaultTimeout = timeout - scriptDir = scrDir +// index is the main handler of the API. +func index(conf *config.Config) http.Handler { + defaultTimeout = *conf.Timeout + scriptDir = *conf.ScriptDir return http.HandlerFunc(webhookHandler) } diff --git a/pkg/api/router.go b/pkg/api/router.go new file mode 100644 index 0000000..5c8d059 --- /dev/null +++ b/pkg/api/router.go @@ -0,0 +1,36 @@ +package api + +import ( + "fmt" + "net/http" + "time" + + "github.com/ncarlier/webhookd/pkg/auth" + "github.com/ncarlier/webhookd/pkg/config" + "github.com/ncarlier/webhookd/pkg/middleware" +) + +// NewRouter creates router with declared routes +func NewRouter(conf *config.Config) *http.ServeMux { + router := http.NewServeMux() + authenticator := auth.NewAuthenticator(conf) + + nextRequestID := func() string { + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + + for _, route := range routes { + var handler http.Handler + + handler = route.HandlerFunc(conf) + handler = middleware.Logger(handler) + handler = middleware.Tracing(nextRequestID)(handler) + + if authenticator != nil { + handler = middleware.Auth(handler, authenticator) + } + router.Handle(route.Path, handler) + } + + return router +} diff --git a/pkg/api/routes.go b/pkg/api/routes.go new file mode 100644 index 0000000..14ebbc9 --- /dev/null +++ b/pkg/api/routes.go @@ -0,0 +1,33 @@ +package api + +import ( + "net/http" + + "github.com/ncarlier/webhookd/pkg/config" +) + +// HandlerFunc custom function handler +type HandlerFunc func(conf *config.Config) http.Handler + +// Route is the structure of an HTTP route definition +type Route struct { + Method string + Path string + HandlerFunc HandlerFunc +} + +// Routes is a list of Route +type Routes []Route + +var routes = Routes{ + Route{ + "GET", + "/", + index, + }, + Route{ + "GET", + "/healtz", + healthz, + }, +} diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go new file mode 100644 index 0000000..0d831a8 --- /dev/null +++ b/pkg/auth/auth.go @@ -0,0 +1,23 @@ +package auth + +import ( + "net/http" + + "github.com/ncarlier/webhookd/pkg/config" + "github.com/ncarlier/webhookd/pkg/logger" +) + +// Authenticator is a generic interface to validate an HTTP request +type Authenticator interface { + Validate(r *http.Request) bool +} + +// NewAuthenticator creates new authenticator form the configuration +func NewAuthenticator(conf *config.Config) Authenticator { + authenticator, err := NewHtpasswdFromFile(*conf.PasswdFile) + if err != nil { + logger.Debug.Printf("unable to load htpasswd file: \"%s\" (%s)\n", *conf.PasswdFile, err) + return nil + } + return authenticator +} diff --git a/pkg/auth/authmethod.go b/pkg/auth/authmethod.go deleted file mode 100644 index 775e218..0000000 --- a/pkg/auth/authmethod.go +++ /dev/null @@ -1,27 +0,0 @@ -package auth - -import "net/http" - -// Method an interface describing an authentication method -type Method interface { - // Called after ParseParam method. - // auth.Method should initialize itself here and get ready to receive requests. - // Logger has been initialized so it is safe to call logger methods here. - Init(debug bool) - // Return Method Usage Info - Usage() string - // Parse the parameter passed through the -authparam flag - // Logger is not initialized at this state so do NOT call logger methods - // If the parameter is unacceptable, return an error and main should exit - ParseParam(string) error - // Return a middleware to handle connections. - Middleware() func(http.Handler) http.Handler -} - -var ( - // AvailableMethods Returns a map of available auth methods - AvailableMethods = map[string]Method{ - "none": new(noAuth), - "basic": new(basicAuth), - } -) diff --git a/pkg/auth/basic.go b/pkg/auth/basic.go deleted file mode 100644 index cfe4e84..0000000 --- a/pkg/auth/basic.go +++ /dev/null @@ -1,59 +0,0 @@ -package auth - -import ( - "errors" - "fmt" - "net/http" - "strings" - - "github.com/ncarlier/webhookd/pkg/logger" -) - -type basicAuth struct { - username string - password string - authheader string -} - -func (c *basicAuth) Init(_ bool) {} - -func (c *basicAuth) Usage() string { - return "HTTP Basic Auth. Usage: -auth basic -authparam :[:] (example: -auth basic -auth-param foo:bar)" -} - -func (c *basicAuth) ParseParam(param string) error { - res := strings.Split(param, ":") - realm := "Authentication required." - switch len(res) { - case 3: - realm = res[2] - fallthrough - case 2: - c.username, c.password = res[0], res[1] - c.authheader = fmt.Sprintf("Basic realm=\"%s\"", realm) - return nil - } - return errors.New("Invalid Auth param") - -} - -// BasicAuth HTTP Basic Auth implementation -func (c *basicAuth) Middleware() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if username, password, ok := r.BasicAuth(); ok && username == c.username && password == c.password { - logger.Info.Printf("HTTP Basic Auth: %s PASSED\n", username) - next.ServeHTTP(w, r) - } else if !ok { - logger.Debug.Println("HTTP Basic Auth: Auth header not present.") - w.Header().Add("WWW-Authenticate", c.authheader) - w.WriteHeader(401) - w.Write([]byte("Authentication required.")) - } else { - logger.Warning.Printf("HTTP Basic Auth: Invalid credentials for username %s\n", username) - w.WriteHeader(403) - w.Write([]byte("Forbidden.")) - } - }) - } -} diff --git a/pkg/auth/htpasswd-file.go b/pkg/auth/htpasswd-file.go new file mode 100644 index 0000000..29fd440 --- /dev/null +++ b/pkg/auth/htpasswd-file.go @@ -0,0 +1,95 @@ +package auth + +import ( + "crypto/sha1" + "encoding/base64" + "encoding/csv" + "net/http" + "os" + "regexp" + "strings" + + "golang.org/x/crypto/bcrypt" +) + +var ( + shaRe = regexp.MustCompile(`^{SHA}`) + bcrRe = regexp.MustCompile(`^\$2b\$|^\$2a\$|^\$2y\$`) +) + +// HtpasswdFile is a map for usernames to passwords. +type HtpasswdFile struct { + path string + users map[string]string +} + +// NewHtpasswdFromFile reads the users and passwords from a htpasswd file and returns them. +func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { + r, err := os.Open(path) + if err != nil { + return nil, err + } + defer r.Close() + + cr := csv.NewReader(r) + cr.Comma = ':' + cr.Comment = '#' + cr.TrimLeadingSpace = true + + records, err := cr.ReadAll() + if err != nil { + return nil, err + } + + users := make(map[string]string) + for _, record := range records { + users[record[0]] = record[1] + } + + return &HtpasswdFile{ + path: path, + users: users, + }, nil +} + +// Validate HTTP request credentials +func (h *HtpasswdFile) Validate(r *http.Request) bool { + s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + if len(s) != 2 { + return false + } + + b, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + return false + } + + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return false + } + + return h.validateCredentials(pair[0], pair[1]) +} + +func (h *HtpasswdFile) validateCredentials(user string, password string) bool { + pwd, exists := h.users[user] + if !exists { + return false + } + + switch { + case shaRe.MatchString(pwd): + d := sha1.New() + _, _ = d.Write([]byte(password)) + if pwd[5:] == base64.StdEncoding.EncodeToString(d.Sum(nil)) { + return true + } + case bcrRe.MatchString(pwd): + err := bcrypt.CompareHashAndPassword([]byte(pwd), []byte(password)) + if err == nil { + return true + } + } + return false +} diff --git a/pkg/auth/htpasswd-file_test.go b/pkg/auth/htpasswd-file_test.go new file mode 100644 index 0000000..c131350 --- /dev/null +++ b/pkg/auth/htpasswd-file_test.go @@ -0,0 +1,15 @@ +package auth + +import ( + "testing" + + "github.com/ncarlier/webhookd/pkg/assert" +) + +func TestValidateCredentials(t *testing.T) { + htpasswdFile, err := NewHtpasswdFromFile("test.htpasswd") + assert.Nil(t, err, ".htpasswd file should be loaded") + assert.NotNil(t, htpasswdFile, ".htpasswd file should be loaded") + assert.Equal(t, true, htpasswdFile.validateCredentials("foo", "bar"), "credentials should be valid") + assert.Equal(t, false, htpasswdFile.validateCredentials("foo", "bir"), "credentials should not be valid") +} diff --git a/pkg/auth/none.go b/pkg/auth/none.go deleted file mode 100644 index 6e8905a..0000000 --- a/pkg/auth/none.go +++ /dev/null @@ -1,25 +0,0 @@ -package auth - -import ( - "net/http" -) - -type noAuth struct { -} - -func (c *noAuth) Usage() string { - return "No Auth. Usage: -auth none" -} - -func (c *noAuth) Init(_ bool) {} - -func (c *noAuth) ParseParam(_ string) error { - return nil -} - -// NoAuth A Nop Auth middleware -func (c *noAuth) Middleware() func(http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - return h - } -} diff --git a/pkg/auth/test.htpasswd b/pkg/auth/test.htpasswd new file mode 100644 index 0000000..2d6e571 --- /dev/null +++ b/pkg/auth/test.htpasswd @@ -0,0 +1,2 @@ +# htpasswd -B -c test.htpasswd foo +foo:$2y$05$068L1J0kA3FEh8jHSlnluut4gYleWd47Ig/AWztz8/8bQS6tHvtd. diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..00bca53 --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,60 @@ +package config + +import ( + "flag" + "os" + "strconv" +) + +// Config contain global configuration +type Config struct { + ListenAddr *string + NbWorkers *int + Debug *bool + Timeout *int + ScriptDir *string + PasswdFile *string +} + +var config = &Config{ + ListenAddr: flag.String("listen", getEnv("LISTEN_ADDR", ":8080"), "HTTP service address (e.g.address, ':8080')"), + NbWorkers: flag.Int("nb-workers", getIntEnv("NB_WORKERS", 2), "The number of workers to start"), + Debug: flag.Bool("debug", getBoolEnv("DEBUG", false), "Output debug logs"), + Timeout: flag.Int("timeout", getIntEnv("HOOK_TIMEOUT", 10), "Hook maximum delay before timeout (in second)"), + ScriptDir: flag.String("scripts", getEnv("SCRIPTS_DIR", "scripts"), "Scripts directory"), + PasswdFile: flag.String("passwd", getEnv("PASSWD_FILE", ".htpasswd"), "Password file (encoded with htpasswd)"), +} + +func init() { + flag.StringVar(config.ListenAddr, "l", *config.ListenAddr, "HTTP service (e.g address: ':8080')") + flag.BoolVar(config.Debug, "d", *config.Debug, "Output debug logs") + flag.StringVar(config.PasswdFile, "p", *config.PasswdFile, "Password file (encoded with htpasswd)") +} + +// Get global configuration +func Get() *Config { + return config +} + +func getEnv(key, fallback string) string { + if value, ok := os.LookupEnv("APP_" + key); ok { + return value + } + return fallback +} + +func getIntEnv(key string, fallback int) int { + strValue := getEnv(key, strconv.Itoa(fallback)) + if value, err := strconv.Atoi(strValue); err == nil { + return value + } + return fallback +} + +func getBoolEnv(key string, fallback bool) bool { + strValue := getEnv(key, strconv.FormatBool(fallback)) + if value, err := strconv.ParseBool(strValue); err == nil { + return value + } + return fallback +} diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go new file mode 100644 index 0000000..0cff402 --- /dev/null +++ b/pkg/middleware/auth.go @@ -0,0 +1,20 @@ +package middleware + +import ( + "net/http" + + "github.com/ncarlier/webhookd/pkg/auth" +) + +// Auth is a middleware to checks HTTP request credentials +func Auth(inner http.Handler, authn auth.Authenticator) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if authn.Validate(r) { + inner.ServeHTTP(w, r) + return + } + w.Header().Set("WWW-Authenticate", `Basic realm="Ah ah ah, you didn't say the magic word"`) + w.WriteHeader(401) + w.Write([]byte("401 Unauthorized\n")) + }) +} diff --git a/pkg/middleware/logger.go b/pkg/middleware/logger.go new file mode 100644 index 0000000..c172de2 --- /dev/null +++ b/pkg/middleware/logger.go @@ -0,0 +1,29 @@ +package middleware + +import ( + "net/http" + "time" + + "github.com/ncarlier/webhookd/pkg/logger" +) + +type key int + +const ( + requestIDKey key = 0 +) + +// Logger is a middleware to log HTTP request +func Logger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + defer func() { + requestID, ok := r.Context().Value(requestIDKey).(string) + if !ok { + requestID = "unknown" + } + logger.Info.Println(requestID, r.Method, r.URL.Path, r.RemoteAddr, r.UserAgent(), time.Since(start)) + }() + next.ServeHTTP(w, r) + }) +} diff --git a/pkg/middleware/tracing.go b/pkg/middleware/tracing.go new file mode 100644 index 0000000..e69029a --- /dev/null +++ b/pkg/middleware/tracing.go @@ -0,0 +1,21 @@ +package middleware + +import ( + "context" + "net/http" +) + +// Tracing is a middleware to trace HTTP request +func Tracing(nextRequestID func() string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestID := r.Header.Get("X-Request-Id") + if requestID == "" { + requestID = nextRequestID() + } + ctx := context.WithValue(r.Context(), requestIDKey, requestID) + w.Header().Set("X-Request-Id", requestID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +}