diff --git a/internal/web/frontend_server.go b/internal/web/frontend_server.go index 469b30a..d108f50 100644 --- a/internal/web/frontend_server.go +++ b/internal/web/frontend_server.go @@ -31,7 +31,7 @@ func frontendHandler(notFoundHandler http.HandlerFunc) http.HandlerFunc { } return } - writeErr(writer, err) + writeErr(request, writer, err) return } defer func() { @@ -40,7 +40,7 @@ func frontendHandler(notFoundHandler http.HandlerFunc) http.HandlerFunc { fileInfo, err := file.Stat() if err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } @@ -55,7 +55,7 @@ func frontendHandler(notFoundHandler http.HandlerFunc) http.HandlerFunc { content, err := io.ReadAll(file) if err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } @@ -65,10 +65,10 @@ func frontendHandler(notFoundHandler http.HandlerFunc) http.HandlerFunc { } } -func serveIndexFile(writer http.ResponseWriter, _ *http.Request) { +func serveIndexFile(writer http.ResponseWriter, request *http.Request) { indexFile, err := frontend.ReadFile("frontend/index.html") if err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } writer.Header().Set("Content-Type", "text/html") diff --git a/internal/web/response_writer.go b/internal/web/response_writer.go index 9f3a947..accaee9 100644 --- a/internal/web/response_writer.go +++ b/internal/web/response_writer.go @@ -2,10 +2,12 @@ package web import ( "encoding/json" + "github.com/lus/pasty/pkg/chizerolog" "net/http" ) -func writeErr(writer http.ResponseWriter, err error) { +func writeErr(request *http.Request, writer http.ResponseWriter, err error) { + chizerolog.InjectError(request, err) writeString(writer, http.StatusInternalServerError, err.Error()) } @@ -26,8 +28,8 @@ func writeJSON(writer http.ResponseWriter, status int, value any) error { return nil } -func writeJSONOrErr(writer http.ResponseWriter, status int, value any) { +func writeJSONOrErr(request *http.Request, writer http.ResponseWriter, status int, value any) { if err := writeJSON(writer, status, value); err != nil { - writeErr(writer, err) + writeErr(request, writer, err) } } diff --git a/internal/web/server.go b/internal/web/server.go index a95eeb2..da3434b 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -7,6 +7,7 @@ import ( "github.com/lus/pasty/internal/pastes" "github.com/lus/pasty/internal/reports" "github.com/lus/pasty/internal/storage" + "github.com/lus/pasty/pkg/chizerolog" "net/http" ) @@ -49,6 +50,9 @@ type Server struct { func (server *Server) Start() error { router := chi.NewRouter() + router.Use(chizerolog.Logger) + router.Use(chizerolog.Recover) + // Register the web frontend handler router.Get("/*", frontendHandler(router.NotFoundHandler())) @@ -72,7 +76,7 @@ func (server *Server) Start() error { router.With(server.v2MiddlewareInjectPaste).Post("/api/v2/pastes/{paste_id}/report", server.v2EndpointReportPaste) } router.Get("/api/v2/info", func(writer http.ResponseWriter, request *http.Request) { - writeJSONOrErr(writer, http.StatusOK, map[string]any{ + writeJSONOrErr(request, writer, http.StatusOK, map[string]any{ "version": meta.Version, "modificationTokens": server.ModificationTokensEnabled, "reports": server.ReportClient != nil, diff --git a/internal/web/v2_end_create_paste.go b/internal/web/v2_end_create_paste.go index 4ac0cb4..e114043 100644 --- a/internal/web/v2_end_create_paste.go +++ b/internal/web/v2_end_create_paste.go @@ -18,12 +18,12 @@ func (server *Server) v2EndpointCreatePaste(writer http.ResponseWriter, request // Read, parse and validate the request payload body, err := io.ReadAll(request.Body) if err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } payload := new(v2EndpointCreatePastePayload) if err := json.Unmarshal(body, payload); err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } if payload.Content == "" { @@ -37,7 +37,7 @@ func (server *Server) v2EndpointCreatePaste(writer http.ResponseWriter, request id, err := pastes.GenerateID(request.Context(), server.Storage.Pastes(), server.PasteIDCharset, server.PasteIDLength) if err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } @@ -54,17 +54,17 @@ func (server *Server) v2EndpointCreatePaste(writer http.ResponseWriter, request paste.ModificationToken = modificationToken if err := paste.HashModificationToken(); err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } } if err := server.Storage.Pastes().Upsert(request.Context(), paste); err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } cpy := *paste cpy.ModificationToken = modificationToken - writeJSONOrErr(writer, http.StatusCreated, cpy) + writeJSONOrErr(request, writer, http.StatusCreated, cpy) } diff --git a/internal/web/v2_end_delete_paste.go b/internal/web/v2_end_delete_paste.go index f5e6c1b..09926c0 100644 --- a/internal/web/v2_end_delete_paste.go +++ b/internal/web/v2_end_delete_paste.go @@ -13,6 +13,7 @@ func (server *Server) v2EndpointDeletePaste(writer http.ResponseWriter, request } if err := server.Storage.Pastes().DeleteByID(request.Context(), paste.ID); err != nil { - writeErr(writer, err) + writeErr(request, writer, err) } + writer.WriteHeader(http.StatusOK) } diff --git a/internal/web/v2_end_get_paste.go b/internal/web/v2_end_get_paste.go index 5be338e..de18c82 100644 --- a/internal/web/v2_end_get_paste.go +++ b/internal/web/v2_end_get_paste.go @@ -1,6 +1,7 @@ package web import ( + "errors" "github.com/lus/pasty/internal/pastes" "net/http" ) @@ -8,11 +9,11 @@ import ( func (server *Server) v2EndpointGetPaste(writer http.ResponseWriter, request *http.Request) { paste, ok := request.Context().Value("paste").(*pastes.Paste) if !ok { - writeString(writer, http.StatusInternalServerError, "missing paste object") + writeErr(request, writer, errors.New("missing paste object")) return } cpy := *paste cpy.ModificationToken = "" - writeJSONOrErr(writer, http.StatusOK, cpy) + writeJSONOrErr(request, writer, http.StatusOK, cpy) } diff --git a/internal/web/v2_end_modify_paste.go b/internal/web/v2_end_modify_paste.go index ad648ea..6c58882 100644 --- a/internal/web/v2_end_modify_paste.go +++ b/internal/web/v2_end_modify_paste.go @@ -22,12 +22,12 @@ func (server *Server) v2EndpointModifyPaste(writer http.ResponseWriter, request // Read, parse and validate the request payload body, err := io.ReadAll(request.Body) if err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } payload := new(v2EndpointModifyPastePayload) if err := json.Unmarshal(body, payload); err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } if payload.Content != nil && *payload.Content == "" { @@ -55,6 +55,6 @@ func (server *Server) v2EndpointModifyPaste(writer http.ResponseWriter, request // Save the modified paste if err := server.Storage.Pastes().Upsert(request.Context(), paste); err != nil { - writeErr(writer, err) + writeErr(request, writer, err) } } diff --git a/internal/web/v2_end_report_paste.go b/internal/web/v2_end_report_paste.go index 186c88e..b48f1fc 100644 --- a/internal/web/v2_end_report_paste.go +++ b/internal/web/v2_end_report_paste.go @@ -22,12 +22,12 @@ func (server *Server) v2EndpointReportPaste(writer http.ResponseWriter, request // Read, parse and validate the request payload body, err := io.ReadAll(request.Body) if err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } payload := new(v2EndpointReportPastePayload) if err := json.Unmarshal(body, payload); err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } if payload.Reason == "" { @@ -41,8 +41,8 @@ func (server *Server) v2EndpointReportPaste(writer http.ResponseWriter, request } response, err := server.ReportClient.Send(report) if err != nil { - writeErr(writer, err) + writeErr(request, writer, err) return } - writeJSONOrErr(writer, http.StatusOK, response) + writeJSONOrErr(request, writer, http.StatusOK, response) } diff --git a/internal/web/v2_mid_inject_paste.go b/internal/web/v2_mid_inject_paste.go index 399a1f2..26657f9 100644 --- a/internal/web/v2_mid_inject_paste.go +++ b/internal/web/v2_mid_inject_paste.go @@ -17,10 +17,7 @@ func (server *Server) v2MiddlewareInjectPaste(next http.Handler) http.Handler { paste, err := server.Storage.Pastes().FindByID(request.Context(), pasteID) if err != nil { - if pasteID == "" { - writeErr(writer, err) - return - } + writeErr(request, writer, err) } if paste == nil { writeString(writer, http.StatusNotFound, "paste not found") diff --git a/pkg/chizerolog/logger.go b/pkg/chizerolog/logger.go new file mode 100644 index 0000000..2246ab0 --- /dev/null +++ b/pkg/chizerolog/logger.go @@ -0,0 +1,78 @@ +package chizerolog + +import ( + "context" + "fmt" + "github.com/go-chi/chi/v5/middleware" + "github.com/rs/zerolog/log" + "net/http" + "time" +) + +const dataKey = "chzl_meta" + +// Logger uses the global zerolog logger to log HTTP requests. +// Log messages are printed with the debug level. +// This middleware should be registered first. +func Logger(next http.Handler) http.Handler { + fn := func(writer http.ResponseWriter, request *http.Request) { + request = request.WithContext(context.WithValue(request.Context(), dataKey, make(map[string]any))) + + proxy := middleware.NewWrapResponseWriter(writer, request.ProtoMajor) + + start := time.Now() + defer func() { + end := time.Now() + + scheme := "http" + if request.TLS != nil { + scheme = "https" + } + url := fmt.Sprintf("%s://%s%s", scheme, request.Host, request.RequestURI) + + var err error + data := request.Context().Value(dataKey) + if data != nil { + injErr, ok := data.(map[string]any)["err"] + if ok { + err = injErr.(error) + } + } + + if err == nil { + log.Debug(). + Str("proto", request.Proto). + Str("method", request.Method). + Str("route", url). + Str("client_address", request.RemoteAddr). + Int("response_size", proxy.BytesWritten()). + Str("elapsed", fmt.Sprintf("%s", end.Sub(start))). + Int("status_code", proxy.Status()). + Msg("An incoming request has been processed.") + } else { + log.Error(). + Err(err). + Str("proto", request.Proto). + Str("method", request.Method). + Str("route", url). + Str("client_address", request.RemoteAddr). + Int("response_size", proxy.BytesWritten()). + Str("elapsed", fmt.Sprintf("%s", end.Sub(start))). + Int("status_code", proxy.Status()). + Msg("An incoming request has been processed and resulted in an unexpected error.") + } + }() + + next.ServeHTTP(proxy, request) + } + return http.HandlerFunc(fn) +} + +// InjectError injects the given error to a specific key so that Logger will log its occurrence later on in the request chain. +func InjectError(request *http.Request, err error) { + data := request.Context().Value(dataKey) + if data == nil { + return + } + data.(map[string]any)["err"] = err +} diff --git a/pkg/chizerolog/recoverer.go b/pkg/chizerolog/recoverer.go new file mode 100644 index 0000000..0bd4965 --- /dev/null +++ b/pkg/chizerolog/recoverer.go @@ -0,0 +1,36 @@ +package chizerolog + +import ( + "fmt" + "github.com/rs/zerolog/log" + "net/http" + "runtime/debug" +) + +// Recover recovers any call to panic() made by a request handler or middleware. +// It also logs an error-levelled message using the global zerolog logger. +// This middleware should be registered first (or second if Logger is also used). +func Recover(next http.Handler) http.Handler { + fn := func(writer http.ResponseWriter, request *http.Request) { + defer func() { + scheme := "http" + if request.TLS != nil { + scheme = "https" + } + url := fmt.Sprintf("%s://%s%s", scheme, request.Host, request.RequestURI) + + if rec := recover(); rec != nil { + log.Error(). + Str("proto", request.Proto). + Str("method", request.Method). + Str("route", url). + Interface("recovered", rec). + Bytes("stack", debug.Stack()). + Msg("A request handler has panicked.") + http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } + }() + next.ServeHTTP(writer, request) + } + return http.HandlerFunc(fn) +}