refactore(middleware): middlewares small refactoring

This commit is contained in:
Nicolas Carlier 2020-02-29 08:15:12 +00:00
parent 75aac478a2
commit 1d3da80680
12 changed files with 162 additions and 110 deletions

View File

@ -2,9 +2,11 @@ package api
import (
"bytes"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/ncarlier/webhookd/pkg/strcase"
)
@ -42,3 +44,7 @@ func HTTPHeadersToShellVars(h http.Header) []string {
}
return params
}
func nextRequestID() string {
return fmt.Sprintf("%d", time.Now().UnixNano())
}

View File

@ -1,9 +1,7 @@
package api
import (
"fmt"
"net/http"
"time"
"github.com/ncarlier/webhookd/pkg/auth"
"github.com/ncarlier/webhookd/pkg/config"
@ -12,18 +10,19 @@ import (
"github.com/ncarlier/webhookd/pkg/pubkey"
)
func nextRequestID() string {
return fmt.Sprintf("%d", time.Now().UnixNano())
var commonMiddlewares = []middleware.Middleware{
middleware.Cors,
middleware.Tracing(nextRequestID),
middleware.Logger,
}
// NewRouter creates router with declared routes
func NewRouter(conf *config.Config) *http.ServeMux {
router := http.NewServeMux()
// Load authenticator...
authenticator, err := auth.NewHtpasswdFromFile(conf.PasswdFile)
if err != nil {
logger.Debug.Printf("unable to load htpasswd file (\"%s\"): %s\n", conf.PasswdFile, err)
var middlewares = commonMiddlewares
if conf.TLSListenAddr != "" {
middlewares = append(middlewares, middleware.HSTS)
}
// Load key store...
@ -31,25 +30,27 @@ func NewRouter(conf *config.Config) *http.ServeMux {
if err != nil {
logger.Warning.Printf("unable to load key store (\"%s\"): %s\n", conf.KeyStoreURI, err)
}
if keystore != nil {
middlewares = append(middlewares, middleware.HTTPSignature(keystore))
}
// Load authenticator...
authenticator, err := auth.NewHtpasswdFromFile(conf.PasswdFile)
if err != nil {
logger.Debug.Printf("unable to load htpasswd file (\"%s\"): %s\n", conf.PasswdFile, err)
}
if authenticator != nil {
middlewares = append(middlewares, middleware.AuthN(authenticator))
}
// Register HTTP routes...
for _, route := range routes {
var handler http.Handler
handler = route.HandlerFunc(conf)
handler = middleware.Method(handler, route.Methods)
handler = middleware.Cors(handler)
if conf.TLSListenAddr != "" {
handler = middleware.HSTS(handler)
handler := route.HandlerFunc(conf)
for _, mw := range route.Middlewares {
handler = mw(handler)
}
handler = middleware.Logger(handler)
handler = middleware.Tracing(nextRequestID)(handler)
if keystore != nil {
handler = middleware.HTTPSignature(handler, keystore)
}
if authenticator != nil {
handler = middleware.Auth(handler, authenticator)
for _, mw := range middlewares {
handler = mw(handler)
}
router.Handle(route.Path, handler)
}

View File

@ -4,6 +4,7 @@ import (
"net/http"
"github.com/ncarlier/webhookd/pkg/config"
"github.com/ncarlier/webhookd/pkg/middleware"
)
// HandlerFunc custom function handler
@ -11,28 +12,24 @@ type HandlerFunc func(conf *config.Config) http.Handler
// Route is the structure of an HTTP route definition
type Route struct {
Methods []string
Path string
HandlerFunc HandlerFunc
Middlewares []middleware.Middleware
}
func route(path string, handler HandlerFunc, middlewares ...middleware.Middleware) Route {
return Route{
Path: path,
HandlerFunc: handler,
Middlewares: middlewares,
}
}
// Routes is a list of Route
type Routes []Route
var routes = Routes{
Route{
[]string{"GET", "POST"},
"/",
index,
},
Route{
[]string{"GET"},
"/healthz",
healthz,
},
Route{
[]string{"GET"},
"/varz",
varz,
},
route("/", index, middleware.Methods("GET", "POST")),
route("/healthz", healthz, middleware.Methods("GET")),
route("/varz", varz, middleware.Methods("GET")),
}

36
pkg/logger/color.go Normal file
View File

@ -0,0 +1,36 @@
package logger
var (
nocolor = "\033[0m"
red = "\033[0;31m"
green = "\033[0;32m"
orange = "\033[0;33m"
blue = "\033[0;34m"
purple = "\033[0;35m"
cyan = "\033[0;36m"
gray = "\033[0;37m"
)
func colorize(text string, color string) string {
return color + text + nocolor
}
// Gray ANSI color applied to a string
func Gray(text string) string {
return colorize(text, gray)
}
// Green ANSI color applied to a string
func Green(text string) string {
return colorize(text, green)
}
// Orange ANSI color applied to a string
func Orange(text string) string {
return colorize(text, orange)
}
// Red ANSI color applied to a string
func Red(text string) string {
return colorize(text, red)
}

View File

@ -42,8 +42,8 @@ func Init(level string) {
commonFlags = log.LstdFlags | log.Lmicroseconds | log.Lshortfile
}
Debug = log.New(debugHandle, "DBG ", commonFlags)
Info = log.New(infoHandle, "INF ", commonFlags)
Warning = log.New(warnHandle, "WRN ", commonFlags)
Error = log.New(errorHandle, "ERR ", commonFlags)
Debug = log.New(debugHandle, Gray("DBG "), commonFlags)
Info = log.New(infoHandle, Green("INF "), commonFlags)
Warning = log.New(warnHandle, Orange("WRN "), commonFlags)
Error = log.New(errorHandle, Red("ERR "), commonFlags)
}

View File

@ -1,20 +0,0 @@
package middleware
import (
"net/http"
"github.com/ncarlier/webhookd/pkg/auth"
)
// Auth is a middleware to checks HTTP request credentials
func Auth(inner http.Handler, authn auth.Authenticator) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if authn.Validate(r) {
inner.ServeHTTP(w, r)
return
}
w.Header().Set("WWW-Authenticate", `Basic realm="Ah ah ah, you didn't say the magic word"`)
w.WriteHeader(401)
w.Write([]byte("401 Unauthorized\n"))
})
}

22
pkg/middleware/authn.go Normal file
View File

@ -0,0 +1,22 @@
package middleware
import (
"net/http"
"github.com/ncarlier/webhookd/pkg/auth"
)
// AuthN is a middleware to checks HTTP request credentials
func AuthN(authenticator auth.Authenticator) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if authenticator.Validate(r) {
next.ServeHTTP(w, r)
return
}
w.Header().Set("WWW-Authenticate", `Basic realm="Ah ah ah, you didn't say the magic word"`)
w.WriteHeader(401)
w.Write([]byte("401 Unauthorized\n"))
})
}
}

View File

@ -1,23 +0,0 @@
package middleware
import (
"net/http"
)
// Method is a middleware to check that the request use the correct HTTP method
func Method(inner http.Handler, methods []string) http.Handler {
allowedMethods := make(map[string]struct{}, len(methods))
for _, s := range methods {
allowedMethods[s] = struct{}{}
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, ok := allowedMethods[r.Method]; ok {
inner.ServeHTTP(w, r)
return
}
w.WriteHeader(405)
w.Write([]byte("405 Method Not Allowed\n"))
return
})
}

25
pkg/middleware/methods.go Normal file
View File

@ -0,0 +1,25 @@
package middleware
import (
"net/http"
)
// Methods is a middleware to check that the request use the correct HTTP method
func Methods(methods ...string) Middleware {
return func(next http.Handler) http.Handler {
allowedMethods := make(map[string]struct{}, len(methods))
for _, s := range methods {
allowedMethods[s] = struct{}{}
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, ok := allowedMethods[r.Method]; ok {
next.ServeHTTP(w, r)
return
}
w.WriteHeader(405)
w.Write([]byte("405 Method Not Allowed\n"))
return
})
}
}

View File

@ -8,27 +8,29 @@ import (
)
// HTTPSignature is a middleware to checks HTTP request signature
func HTTPSignature(inner http.Handler, keyStore pubkey.KeyStore) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
verifier, err := httpsig.NewVerifier(r)
if err != nil {
w.WriteHeader(400)
w.Write([]byte("invalid HTTP signature: " + err.Error()))
return
}
pubKeyID := verifier.KeyId()
pubKey, algo, err := keyStore.Get(pubKeyID)
if err != nil {
w.WriteHeader(400)
w.Write([]byte("invalid HTTP signature: " + err.Error()))
return
}
err = verifier.Verify(pubKey, algo)
if err != nil {
w.WriteHeader(400)
w.Write([]byte("invalid HTTP signature: " + err.Error()))
return
}
inner.ServeHTTP(w, r)
})
func HTTPSignature(keyStore pubkey.KeyStore) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
verifier, err := httpsig.NewVerifier(r)
if err != nil {
w.WriteHeader(400)
w.Write([]byte("invalid HTTP signature: " + err.Error()))
return
}
pubKeyID := verifier.KeyId()
pubKey, algo, err := keyStore.Get(pubKeyID)
if err != nil {
w.WriteHeader(400)
w.Write([]byte("invalid HTTP signature: " + err.Error()))
return
}
err = verifier.Verify(pubKey, algo)
if err != nil {
w.WriteHeader(400)
w.Write([]byte("invalid HTTP signature: " + err.Error()))
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@ -6,7 +6,7 @@ import (
)
// Tracing is a middleware to trace HTTP request
func Tracing(nextRequestID func() string) func(http.Handler) http.Handler {
func Tracing(nextRequestID func() string) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestID := r.Header.Get("X-Request-Id")

6
pkg/middleware/types.go Normal file
View File

@ -0,0 +1,6 @@
package middleware
import "net/http"
// Middleware function definition
type Middleware func(inner http.Handler) http.Handler