gotest/web/middleware.go

151 lines
3.6 KiB
Go

package web
import (
"context"
"net/http"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
)
// contextKey is the unexported key type for values this package stores in
// request contexts. Because the type is private to the package, no other
// package can construct an equal key.
type contextKey int

// Keys for the per-request values that loggingMw places in the request
// context and the *FromContext helpers read back.
const (
	requestIDKey     contextKey = 0 // request ID (string)
	correlationIDKey contextKey = 1 // correlation ID (string)
	loggerKey        contextKey = 2 // request-scoped logger (*log.Entry)
)
// Header names used by the middleware in this package.
//
// Both constants are untyped string constants so they can be used wherever a
// string (or a named string type) is expected.
const (
	// RequestIDHeader is the response header that carries the ID generated
	// for each request.
	RequestIDHeader = "X-Request-Id"
	// CorrelationIDHeader is the request header read for an incoming
	// correlation ID, and the response header it is echoed back in.
	CorrelationIDHeader = "X-Correlation-Id"
)
// CorrelationIDFromRequest returns the value of the CorrelationIDHeader
// header on req. It returns "" when the header is absent or empty.
func CorrelationIDFromRequest(req *http.Request) string {
	return req.Header.Get(CorrelationIDHeader)
}
// getFromContext looks up key in ctx and returns the stored value as a string.
// ok reports whether a string value was found. A missing key, or a value of
// any other type stored under key, yields ("", false) rather than a panic.
func getFromContext(ctx context.Context, key contextKey) (value string, ok bool) {
	// Comma-ok assertion: a nil interface or a non-string value both give ok == false.
	value, ok = ctx.Value(key).(string)
	return value, ok
}
// CorrelationIDFromContext returns the correlation ID stored in ctx by
// loggingMw. ok is false when ctx carries no correlation ID.
func CorrelationIDFromContext(ctx context.Context) (value string, ok bool) {
	value, ok = getFromContext(ctx, correlationIDKey)
	return value, ok
}
// RequestIDFromContext returns the request ID stored in ctx by loggingMw.
// ok is false when ctx carries no request ID.
func RequestIDFromContext(ctx context.Context) (value string, ok bool) {
	value, ok = getFromContext(ctx, requestIDKey)
	return value, ok
}
// LoggerFromContext returns the request-scoped logger that loggingMw stores in
// ctx; that logger carries the requestID and correlationID fields. If ctx holds
// no logger, or holds a nil or wrongly-typed value, it falls back to a
// field-less entry on the package-level logrus logger. The result is never nil.
func LoggerFromContext(ctx context.Context) *log.Entry {
	if entry, ok := ctx.Value(loggerKey).(*log.Entry); ok && entry != nil {
		return entry
	}
	return log.WithFields(log.Fields{})
}
// loggingResponseWriter wraps an http.ResponseWriter for access logging. It
// records the response status code, the number of body bytes written, and the
// time the wrapper was created.
//
// This idea is taken from the violetear responsewriter
// https://github.com/nbari/violetear/blob/master/responsewriter.go
type loggingResponseWriter struct {
	http.ResponseWriter
	size, status int
	start        time.Time
	// wroteHeader is set once a final status has been committed, either
	// explicitly or as the implicit 200 of the first Write. net/http ignores
	// later WriteHeader calls, so they must not overwrite status either.
	wroteHeader bool
}

// newLoggingResponseWriter wraps w and starts the elapsed-time clock now. The
// status defaults to 200 OK, which is what net/http sends when a handler
// never calls WriteHeader.
func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
	return &loggingResponseWriter{
		ResponseWriter: w,
		start:          time.Now(),
		status:         http.StatusOK,
	}
}

// Status returns the status code recorded for the response.
func (w *loggingResponseWriter) Status() int {
	return w.status
}

// Size returns the total number of body bytes reported written by the
// underlying writer.
func (w *loggingResponseWriter) Size() int {
	return w.size
}

// ElapsedTime returns the time since the wrapper was created, formatted with
// time.Duration.String.
func (w *loggingResponseWriter) ElapsedTime() string {
	return time.Since(w.start).String()
}

// Unwrap returns the underlying ResponseWriter. http.ResponseController
// (Go 1.20+) uses it to reach optional interfaces such as http.Flusher and
// http.Hijacker, which this wrapper does not implement itself.
func (w *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

// Write forwards data to the underlying writer and adds the byte count it
// reports to the running size. A Write before any final WriteHeader commits an
// implicit 200 OK, so later WriteHeader calls no longer change the status.
func (w *loggingResponseWriter) Write(data []byte) (int, error) {
	if !w.wroteHeader {
		w.status = http.StatusOK
		w.wroteHeader = true
	}
	n, err := w.ResponseWriter.Write(data)
	w.size += n
	return n, err
}

// WriteHeader records statusCode and forwards it to the underlying writer.
// Only the first final status is recorded, matching what net/http actually
// sends. Informational 1xx codes other than 101 Switching Protocols can be
// followed by the real status, so they do not lock the status in.
func (w *loggingResponseWriter) WriteHeader(statusCode int) {
	if !w.wroteHeader {
		w.status = statusCode
		informational := statusCode >= 100 && statusCode < 200 &&
			statusCode != http.StatusSwitchingProtocols
		w.wroteHeader = !informational
	}
	w.ResponseWriter.WriteHeader(statusCode)
}
// middleware installs this package's middleware on r and returns r so calls
// can be chained. gorilla/mux runs middleware in the order it was added, so
// loggingMw is outermost and wraps discworldMw and the matched route handler.
func middleware(r *mux.Router) *mux.Router {
	r.Use(loggingMw, discworldMw)
	// NOTE(review): this log line carries little information; consider a
	// more descriptive message.
	log.Info("using")
	return r
}
// discworldMw sets the "X-Clacks-Overhead: GNU Terry Pratchett" response
// header on every response and then delegates to next.
func discworldMw(next http.Handler) http.Handler {
	const header, tribute = "X-Clacks-Overhead", "GNU Terry Pratchett"
	handler := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set(header, tribute)
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(handler)
}
// loggingMw tags each request with a request ID and a correlation ID and
// emits start/end access-log lines.
//
// The correlation ID is taken from the CorrelationIDHeader request header,
// or freshly generated when that header is empty. A new request ID is always
// generated. Both IDs, plus a logger carrying them as fields, are stored in
// the request context for the *FromContext helpers. Both IDs are also set as
// response headers. The end log line reports status, body size and duration,
// as captured by a loggingResponseWriter.
func loggingMw(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		recorder := newLoggingResponseWriter(w)

		// Reuse the correlation ID supplied in the request, if any;
		// otherwise mint a new one.
		corrID := CorrelationIDFromRequest(r)
		if len(corrID) == 0 {
			corrID = uuid.New().String()
		}
		reqID := uuid.New().String()

		reqLog := log.WithFields(log.Fields{
			"requestID":     reqID,
			"correlationID": corrID,
		})

		ctx := context.WithValue(r.Context(), requestIDKey, reqID)
		ctx = context.WithValue(ctx, correlationIDKey, corrID)
		ctx = context.WithValue(ctx, loggerKey, reqLog)
		r = r.WithContext(ctx)

		hdr := w.Header()
		hdr.Set(CorrelationIDHeader, corrID)
		hdr.Set(RequestIDHeader, reqID)

		accessLog := reqLog.WithFields(log.Fields{
			"path":         r.URL.Path,
			"query_string": r.URL.RawQuery,
			"method":       r.Method,
		})
		accessLog.Info("Start web request")

		next.ServeHTTP(recorder, r)

		accessLog.WithFields(log.Fields{
			"status":   recorder.Status(),
			"size":     recorder.Size(),
			"duration": recorder.ElapsedTime(),
		}).Info("End web request")
	})
}