151 lines
3.6 KiB
Go
151 lines
3.6 KiB
Go
package web
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/mux"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// contextKey is the unexported type used for every value this package
// stores in a request context. Since other packages cannot name this type,
// they cannot create keys that collide with these.
type contextKey int

// Context keys used to store and look up per-request values.
const (
	// requestIDKey holds the request ID string set by loggingMw.
	requestIDKey contextKey = iota
	// correlationIDKey holds the correlation ID string set by loggingMw.
	correlationIDKey
	// loggerKey holds the *log.Entry set by loggingMw.
	loggerKey
)
|
|
|
|
// Header names used by the logging middleware to propagate request and
// correlation IDs.
//
// Both are untyped string constants so that they can be used wherever a
// string-kinded value is expected (previously only CorrelationIDHeader was
// untyped, while RequestIDHeader was pinned to type string).
const (
	// RequestIDHeader is the response header that carries the ID generated
	// for each request.
	RequestIDHeader = "X-Request-Id"
	// CorrelationIDHeader is the header carrying the correlation ID; it is
	// read from incoming requests and echoed on responses.
	CorrelationIDHeader = "X-Correlation-Id"
)
|
|
|
|
// CorrelationIDFromRequest gets a correlation id from the request, if it
|
|
// exists.
|
|
func CorrelationIDFromRequest(req *http.Request) string {
|
|
var id string
|
|
if str := req.Header.Get(CorrelationIDHeader); str != "" {
|
|
id = str
|
|
}
|
|
return id
|
|
}
|
|
|
|
// getFromContext looks up key in ctx and returns the string stored there.
// ok is true only when a string value is present under key. A missing
// value, or defensively a value of any other type, yields ("", false)
// instead of panicking on a failed type assertion.
func getFromContext(ctx context.Context, key contextKey) (value string, ok bool) {
	value, ok = ctx.Value(key).(string)
	return value, ok
}
|
|
|
|
// CorrelationIDFromContext returns the correlation ID stored in ctx under
// correlationIDKey (loggingMw sets it for every request it handles).
// ok is false when ctx carries no such value.
func CorrelationIDFromContext(ctx context.Context) (value string, ok bool) {
	value, ok = getFromContext(ctx, correlationIDKey)
	return value, ok
}
|
|
|
|
// RequestIDFromContext returns the request ID stored in ctx under
// requestIDKey (loggingMw sets it for every request it handles).
// ok is false when ctx carries no such value.
func RequestIDFromContext(ctx context.Context) (value string, ok bool) {
	value, ok = getFromContext(ctx, requestIDKey)
	return value, ok
}
|
|
|
|
// LoggerFromContext returns the request-scoped logger stored in ctx under
// loggerKey. loggingMw stores an entry carrying the requestID and
// correlationID fields there.
//
// When ctx holds no logger, holds a value of another type, or holds a nil
// *log.Entry, it falls back to log.WithFields(log.Fields{}). The result is
// therefore never nil.
func LoggerFromContext(ctx context.Context) *log.Entry {
	// The comma-ok form plus the nil check avoids handing callers a nil entry
	// (a typed nil stored in the context is a non-nil interface value).
	if entry, ok := ctx.Value(loggerKey).(*log.Entry); ok && entry != nil {
		return entry
	}
	return log.WithFields(log.Fields{})
}
|
|
|
|
type loggingResponseWriter struct {
|
|
http.ResponseWriter
|
|
size, status int
|
|
start time.Time
|
|
}
|
|
|
|
// This idea is taken from the violetear responsewriter
|
|
// https://github.com/nbari/violetear/blob/master/responsewriter.go
|
|
func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
|
|
return &loggingResponseWriter{
|
|
ResponseWriter: w,
|
|
start: time.Now(),
|
|
status: http.StatusOK,
|
|
}
|
|
}
|
|
|
|
func (w *loggingResponseWriter) Status() int {
|
|
return w.status
|
|
}
|
|
|
|
func (w *loggingResponseWriter) Size() int {
|
|
return w.size
|
|
}
|
|
|
|
func (w *loggingResponseWriter) ElapsedTime() string {
|
|
return time.Since(w.start).String()
|
|
}
|
|
|
|
func (w *loggingResponseWriter) Write(data []byte) (int, error) {
|
|
size, err := w.ResponseWriter.Write(data)
|
|
w.size += size
|
|
return size, err
|
|
}
|
|
|
|
func (w *loggingResponseWriter) WriteHeader(statusCode int) {
|
|
w.status = statusCode
|
|
w.ResponseWriter.WriteHeader(statusCode)
|
|
}
|
|
|
|
func middleware(r *mux.Router) *mux.Router {
|
|
r.Use(loggingMw)
|
|
r.Use(discworldMw)
|
|
log.Info("using")
|
|
return r
|
|
}
|
|
|
|
// discworldMw sets the "X-Clacks-Overhead: GNU Terry Pratchett" response
// header and then hands the request to next unchanged.
func discworldMw(next http.Handler) http.Handler {
	clacks := func(w http.ResponseWriter, r *http.Request) {
		headers := w.Header()
		headers.Set("X-Clacks-Overhead", "GNU Terry Pratchett")
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(clacks)
}
|
|
|
|
// loggingMw is middleware that tags each request with identifiers and logs
// its start and end.
//
// For every request it:
//   - takes the correlation ID from the CorrelationIDHeader request header,
//     or generates a new UUID when that header is absent or empty;
//   - generates a fresh UUID request ID;
//   - stores both IDs, plus a logger carrying them as "requestID" and
//     "correlationID" fields, in the request context;
//   - echoes both IDs in the CorrelationIDHeader and RequestIDHeader
//     response headers;
//   - logs "Start web request" (path, query string, method) before calling
//     next, and "End web request" (status, size, duration) after it returns.
func loggingMw(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Created first so the elapsed time covers the whole middleware.
		recorder := newLoggingResponseWriter(w)

		correlationID := CorrelationIDFromRequest(r)
		if correlationID == "" {
			correlationID = uuid.New().String()
		}
		requestID := uuid.New().String()

		baseLogger := log.WithFields(log.Fields{
			"requestID":     requestID,
			"correlationID": correlationID,
		})

		ctx := context.WithValue(r.Context(), requestIDKey, requestID)
		ctx = context.WithValue(ctx, correlationIDKey, correlationID)
		ctx = context.WithValue(ctx, loggerKey, baseLogger)
		r = r.WithContext(ctx)

		// recorder shares w's header map, so these reach the client.
		headers := w.Header()
		headers.Set(CorrelationIDHeader, correlationID)
		headers.Set(RequestIDHeader, requestID)

		reqLogger := baseLogger.WithFields(log.Fields{
			"path":         r.URL.Path,
			"query_string": r.URL.RawQuery,
			"method":       r.Method,
		})
		reqLogger.Info("Start web request")

		next.ServeHTTP(recorder, r)

		summary := log.Fields{
			"status":   recorder.Status(),
			"size":     recorder.Size(),
			"duration": recorder.ElapsedTime(),
		}
		reqLogger.WithFields(summary).Info("End web request")
	})
}
|