package web import ( "context" "net/http" "time" "github.com/google/uuid" "github.com/gorilla/mux" log "github.com/sirupsen/logrus" ) type contextKey int // context keys to be used when inspecting request contexts const ( requestIDKey contextKey = iota correlationIDKey loggerKey ) // Header names for middleware to use const ( RequestIDHeader string = "X-Request-Id" CorrelationIDHeader = "X-Correlation-Id" ) // CorrelationIDFromRequest gets a correlation id from the request, if it // exists. func CorrelationIDFromRequest(req *http.Request) string { var id string if str := req.Header.Get(CorrelationIDHeader); str != "" { id = str } return id } func getFromContext(ctx context.Context, key contextKey) (value string, ok bool) { if v := ctx.Value(key); v != nil { value = v.(string) ok = true return } value, ok = "", false return } func CorrelationIDFromContext(ctx context.Context) (value string, ok bool) { return getFromContext(ctx, correlationIDKey) } func RequestIDFromContext(ctx context.Context) (value string, ok bool) { return getFromContext(ctx, requestIDKey) } func LoggerFromContext(ctx context.Context) *log.Entry { if v := ctx.Value(loggerKey); v != nil { return v.(*log.Entry) } return log.WithFields(log.Fields{}) } type loggingResponseWriter struct { http.ResponseWriter size, status int start time.Time } // This idea is taken from the violetear responsewriter // https://github.com/nbari/violetear/blob/master/responsewriter.go func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { return &loggingResponseWriter{ ResponseWriter: w, start: time.Now(), status: http.StatusOK, } } func (w *loggingResponseWriter) Status() int { return w.status } func (w *loggingResponseWriter) Size() int { return w.size } func (w *loggingResponseWriter) ElapsedTime() string { return time.Since(w.start).String() } func (w *loggingResponseWriter) Write(data []byte) (int, error) { size, err := w.ResponseWriter.Write(data) w.size += size return size, err } func (w *loggingResponseWriter) WriteHeader(statusCode int) { w.status = statusCode w.ResponseWriter.WriteHeader(statusCode) } func middleware(r *mux.Router) *mux.Router { r.Use(loggingMw) r.Use(discworldMw) log.Info("using") return r } func discworldMw(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Clacks-Overhead", "GNU Terry Pratchett") next.ServeHTTP(w, r) }) } func loggingMw(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { newW := newLoggingResponseWriter(w) correlationID := CorrelationIDFromRequest(r) if correlationID == "" { correlationID = uuid.New().String() } requestID := uuid.New().String() contextLogger := log.WithFields(log.Fields{ "requestID": requestID, "correlationID": correlationID, }) ctx := r.Context() ctx = context.WithValue(ctx, requestIDKey, requestID) ctx = context.WithValue(ctx, correlationIDKey, correlationID) ctx = context.WithValue(ctx, loggerKey, contextLogger) r = r.WithContext(ctx) w.Header().Set(CorrelationIDHeader, correlationID) w.Header().Set(RequestIDHeader, requestID) requestLogger := contextLogger.WithFields( log.Fields{ "path": r.URL.Path, "query_string": r.URL.RawQuery, "method": r.Method, }) requestLogger.Info("Start web request") next.ServeHTTP(newW, r) requestLogger.WithFields( log.Fields{ "status": newW.Status(), "size": newW.Size(), "duration": newW.ElapsedTime()}). Info("End web request") }) }