mirror of
https://github.com/ncarlier/webhookd.git
synced 2025-04-12 17:47:30 +00:00
refactore(middleware): middlewares small refactoring
This commit is contained in:
parent
75aac478a2
commit
1d3da80680
|
@ -2,9 +2,11 @@ package api
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ncarlier/webhookd/pkg/strcase"
|
||||
)
|
||||
|
@ -42,3 +44,7 @@ func HTTPHeadersToShellVars(h http.Header) []string {
|
|||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func nextRequestID() string {
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ncarlier/webhookd/pkg/auth"
|
||||
"github.com/ncarlier/webhookd/pkg/config"
|
||||
|
@ -12,18 +10,19 @@ import (
|
|||
"github.com/ncarlier/webhookd/pkg/pubkey"
|
||||
)
|
||||
|
||||
func nextRequestID() string {
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
var commonMiddlewares = []middleware.Middleware{
|
||||
middleware.Cors,
|
||||
middleware.Tracing(nextRequestID),
|
||||
middleware.Logger,
|
||||
}
|
||||
|
||||
// NewRouter creates router with declared routes
|
||||
func NewRouter(conf *config.Config) *http.ServeMux {
|
||||
router := http.NewServeMux()
|
||||
|
||||
// Load authenticator...
|
||||
authenticator, err := auth.NewHtpasswdFromFile(conf.PasswdFile)
|
||||
if err != nil {
|
||||
logger.Debug.Printf("unable to load htpasswd file (\"%s\"): %s\n", conf.PasswdFile, err)
|
||||
var middlewares = commonMiddlewares
|
||||
if conf.TLSListenAddr != "" {
|
||||
middlewares = append(middlewares, middleware.HSTS)
|
||||
}
|
||||
|
||||
// Load key store...
|
||||
|
@ -31,25 +30,27 @@ func NewRouter(conf *config.Config) *http.ServeMux {
|
|||
if err != nil {
|
||||
logger.Warning.Printf("unable to load key store (\"%s\"): %s\n", conf.KeyStoreURI, err)
|
||||
}
|
||||
if keystore != nil {
|
||||
middlewares = append(middlewares, middleware.HTTPSignature(keystore))
|
||||
}
|
||||
|
||||
// Load authenticator...
|
||||
authenticator, err := auth.NewHtpasswdFromFile(conf.PasswdFile)
|
||||
if err != nil {
|
||||
logger.Debug.Printf("unable to load htpasswd file (\"%s\"): %s\n", conf.PasswdFile, err)
|
||||
}
|
||||
if authenticator != nil {
|
||||
middlewares = append(middlewares, middleware.AuthN(authenticator))
|
||||
}
|
||||
|
||||
// Register HTTP routes...
|
||||
for _, route := range routes {
|
||||
var handler http.Handler
|
||||
|
||||
handler = route.HandlerFunc(conf)
|
||||
handler = middleware.Method(handler, route.Methods)
|
||||
handler = middleware.Cors(handler)
|
||||
if conf.TLSListenAddr != "" {
|
||||
handler = middleware.HSTS(handler)
|
||||
handler := route.HandlerFunc(conf)
|
||||
for _, mw := range route.Middlewares {
|
||||
handler = mw(handler)
|
||||
}
|
||||
handler = middleware.Logger(handler)
|
||||
handler = middleware.Tracing(nextRequestID)(handler)
|
||||
|
||||
if keystore != nil {
|
||||
handler = middleware.HTTPSignature(handler, keystore)
|
||||
}
|
||||
if authenticator != nil {
|
||||
handler = middleware.Auth(handler, authenticator)
|
||||
for _, mw := range middlewares {
|
||||
handler = mw(handler)
|
||||
}
|
||||
router.Handle(route.Path, handler)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/ncarlier/webhookd/pkg/config"
|
||||
"github.com/ncarlier/webhookd/pkg/middleware"
|
||||
)
|
||||
|
||||
// HandlerFunc custom function handler
|
||||
|
@ -11,28 +12,24 @@ type HandlerFunc func(conf *config.Config) http.Handler
|
|||
|
||||
// Route is the structure of an HTTP route definition
|
||||
type Route struct {
|
||||
Methods []string
|
||||
Path string
|
||||
HandlerFunc HandlerFunc
|
||||
Middlewares []middleware.Middleware
|
||||
}
|
||||
|
||||
func route(path string, handler HandlerFunc, middlewares ...middleware.Middleware) Route {
|
||||
return Route{
|
||||
Path: path,
|
||||
HandlerFunc: handler,
|
||||
Middlewares: middlewares,
|
||||
}
|
||||
}
|
||||
|
||||
// Routes is a list of Route
|
||||
type Routes []Route
|
||||
|
||||
var routes = Routes{
|
||||
Route{
|
||||
[]string{"GET", "POST"},
|
||||
"/",
|
||||
index,
|
||||
},
|
||||
Route{
|
||||
[]string{"GET"},
|
||||
"/healthz",
|
||||
healthz,
|
||||
},
|
||||
Route{
|
||||
[]string{"GET"},
|
||||
"/varz",
|
||||
varz,
|
||||
},
|
||||
route("/", index, middleware.Methods("GET", "POST")),
|
||||
route("/healthz", healthz, middleware.Methods("GET")),
|
||||
route("/varz", varz, middleware.Methods("GET")),
|
||||
}
|
||||
|
|
36
pkg/logger/color.go
Normal file
36
pkg/logger/color.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package logger
|
||||
|
||||
var (
|
||||
nocolor = "\033[0m"
|
||||
red = "\033[0;31m"
|
||||
green = "\033[0;32m"
|
||||
orange = "\033[0;33m"
|
||||
blue = "\033[0;34m"
|
||||
purple = "\033[0;35m"
|
||||
cyan = "\033[0;36m"
|
||||
gray = "\033[0;37m"
|
||||
)
|
||||
|
||||
func colorize(text string, color string) string {
|
||||
return color + text + nocolor
|
||||
}
|
||||
|
||||
// Gray ANSI color applied to a string
|
||||
func Gray(text string) string {
|
||||
return colorize(text, gray)
|
||||
}
|
||||
|
||||
// Green ANSI color applied to a string
|
||||
func Green(text string) string {
|
||||
return colorize(text, green)
|
||||
}
|
||||
|
||||
// Orange ANSI color applied to a string
|
||||
func Orange(text string) string {
|
||||
return colorize(text, orange)
|
||||
}
|
||||
|
||||
// Red ANSI color applied to a string
|
||||
func Red(text string) string {
|
||||
return colorize(text, red)
|
||||
}
|
|
@ -42,8 +42,8 @@ func Init(level string) {
|
|||
commonFlags = log.LstdFlags | log.Lmicroseconds | log.Lshortfile
|
||||
}
|
||||
|
||||
Debug = log.New(debugHandle, "DBG ", commonFlags)
|
||||
Info = log.New(infoHandle, "INF ", commonFlags)
|
||||
Warning = log.New(warnHandle, "WRN ", commonFlags)
|
||||
Error = log.New(errorHandle, "ERR ", commonFlags)
|
||||
Debug = log.New(debugHandle, Gray("DBG "), commonFlags)
|
||||
Info = log.New(infoHandle, Green("INF "), commonFlags)
|
||||
Warning = log.New(warnHandle, Orange("WRN "), commonFlags)
|
||||
Error = log.New(errorHandle, Red("ERR "), commonFlags)
|
||||
}
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/ncarlier/webhookd/pkg/auth"
|
||||
)
|
||||
|
||||
// Auth is a middleware to checks HTTP request credentials
|
||||
func Auth(inner http.Handler, authn auth.Authenticator) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if authn.Validate(r) {
|
||||
inner.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Ah ah ah, you didn't say the magic word"`)
|
||||
w.WriteHeader(401)
|
||||
w.Write([]byte("401 Unauthorized\n"))
|
||||
})
|
||||
}
|
22
pkg/middleware/authn.go
Normal file
22
pkg/middleware/authn.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/ncarlier/webhookd/pkg/auth"
|
||||
)
|
||||
|
||||
// AuthN is a middleware to checks HTTP request credentials
|
||||
func AuthN(authenticator auth.Authenticator) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if authenticator.Validate(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Ah ah ah, you didn't say the magic word"`)
|
||||
w.WriteHeader(401)
|
||||
w.Write([]byte("401 Unauthorized\n"))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Method is a middleware to check that the request use the correct HTTP method
|
||||
func Method(inner http.Handler, methods []string) http.Handler {
|
||||
allowedMethods := make(map[string]struct{}, len(methods))
|
||||
for _, s := range methods {
|
||||
allowedMethods[s] = struct{}{}
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := allowedMethods[r.Method]; ok {
|
||||
inner.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(405)
|
||||
w.Write([]byte("405 Method Not Allowed\n"))
|
||||
return
|
||||
})
|
||||
}
|
25
pkg/middleware/methods.go
Normal file
25
pkg/middleware/methods.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Methods is a middleware to check that the request use the correct HTTP method
|
||||
func Methods(methods ...string) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
allowedMethods := make(map[string]struct{}, len(methods))
|
||||
for _, s := range methods {
|
||||
allowedMethods[s] = struct{}{}
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := allowedMethods[r.Method]; ok {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(405)
|
||||
w.Write([]byte("405 Method Not Allowed\n"))
|
||||
return
|
||||
})
|
||||
}
|
||||
}
|
|
@ -8,27 +8,29 @@ import (
|
|||
)
|
||||
|
||||
// HTTPSignature is a middleware to checks HTTP request signature
|
||||
func HTTPSignature(inner http.Handler, keyStore pubkey.KeyStore) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
verifier, err := httpsig.NewVerifier(r)
|
||||
if err != nil {
|
||||
w.WriteHeader(400)
|
||||
w.Write([]byte("invalid HTTP signature: " + err.Error()))
|
||||
return
|
||||
}
|
||||
pubKeyID := verifier.KeyId()
|
||||
pubKey, algo, err := keyStore.Get(pubKeyID)
|
||||
if err != nil {
|
||||
w.WriteHeader(400)
|
||||
w.Write([]byte("invalid HTTP signature: " + err.Error()))
|
||||
return
|
||||
}
|
||||
err = verifier.Verify(pubKey, algo)
|
||||
if err != nil {
|
||||
w.WriteHeader(400)
|
||||
w.Write([]byte("invalid HTTP signature: " + err.Error()))
|
||||
return
|
||||
}
|
||||
inner.ServeHTTP(w, r)
|
||||
})
|
||||
func HTTPSignature(keyStore pubkey.KeyStore) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
verifier, err := httpsig.NewVerifier(r)
|
||||
if err != nil {
|
||||
w.WriteHeader(400)
|
||||
w.Write([]byte("invalid HTTP signature: " + err.Error()))
|
||||
return
|
||||
}
|
||||
pubKeyID := verifier.KeyId()
|
||||
pubKey, algo, err := keyStore.Get(pubKeyID)
|
||||
if err != nil {
|
||||
w.WriteHeader(400)
|
||||
w.Write([]byte("invalid HTTP signature: " + err.Error()))
|
||||
return
|
||||
}
|
||||
err = verifier.Verify(pubKey, algo)
|
||||
if err != nil {
|
||||
w.WriteHeader(400)
|
||||
w.Write([]byte("invalid HTTP signature: " + err.Error()))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
)
|
||||
|
||||
// Tracing is a middleware to trace HTTP request
|
||||
func Tracing(nextRequestID func() string) func(http.Handler) http.Handler {
|
||||
func Tracing(nextRequestID func() string) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := r.Header.Get("X-Request-Id")
|
||||
|
|
6
pkg/middleware/types.go
Normal file
6
pkg/middleware/types.go
Normal file
|
@ -0,0 +1,6 @@
|
|||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Middleware function definition
|
||||
type Middleware func(inner http.Handler) http.Handler
|
Loading…
Reference in New Issue
Block a user