From 9a794a82d07c254a37fab697b415bdadad5e2eb1 Mon Sep 17 00:00:00 2001 From: Lukas Schulte Pelkum Date: Sat, 17 Jun 2023 20:07:19 +0200 Subject: [PATCH] use native middleware.AllowContentType --- internal/web/request_accept.go | 14 -------------- internal/web/server.go | 7 ++++--- internal/web/v2_end_create_paste.go | 3 --- internal/web/v2_end_modify_paste.go | 3 --- internal/web/v2_end_report_paste.go | 3 --- 5 files changed, 4 insertions(+), 26 deletions(-) delete mode 100644 internal/web/request_accept.go diff --git a/internal/web/request_accept.go b/internal/web/request_accept.go deleted file mode 100644 index 07ddf30..0000000 --- a/internal/web/request_accept.go +++ /dev/null @@ -1,14 +0,0 @@ -package web - -import "net/http" - -func accept(writer http.ResponseWriter, request *http.Request, contentTypes ...string) bool { - contentType := request.Header.Get("Content-Type") - for _, accepted := range contentTypes { - if contentType == accepted { - return true - } - } - writeString(writer, http.StatusUnsupportedMediaType, "unsupported media type") - return false -} diff --git a/internal/web/server.go b/internal/web/server.go index f4f244b..b52846c 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -3,6 +3,7 @@ package web import ( "context" "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" "github.com/lus/pasty/internal/meta" "github.com/lus/pasty/internal/pastes" "github.com/lus/pasty/internal/reports" @@ -67,11 +68,11 @@ func (server *Server) Start() error { // Register the paste API endpoints router.Get("/api/*", router.NotFoundHandler()) router.With(server.v2MiddlewareInjectPaste).Get("/api/v2/pastes/{paste_id}", server.v2EndpointGetPaste) - router.Post("/api/v2/pastes", server.v2EndpointCreatePaste) - router.With(server.v2MiddlewareInjectPaste, server.v2MiddlewareAuthorize).Patch("/api/v2/pastes/{paste_id}", server.v2EndpointModifyPaste) + router.With(middleware.AllowContentType("application/json")).Post("/api/v2/pastes", server.v2EndpointCreatePaste) + router.With(middleware.AllowContentType("application/json"), server.v2MiddlewareInjectPaste, server.v2MiddlewareAuthorize).Patch("/api/v2/pastes/{paste_id}", server.v2EndpointModifyPaste) router.With(server.v2MiddlewareInjectPaste, server.v2MiddlewareAuthorize).Delete("/api/v2/pastes/{paste_id}", server.v2EndpointDeletePaste) if server.ReportClient != nil { - router.With(server.v2MiddlewareInjectPaste).Post("/api/v2/pastes/{paste_id}/report", server.v2EndpointReportPaste) + router.With(middleware.AllowContentType("application/json"), server.v2MiddlewareInjectPaste).Post("/api/v2/pastes/{paste_id}/report", server.v2EndpointReportPaste) } router.Get("/api/v2/info", func(writer http.ResponseWriter, request *http.Request) { writeJSONOrErr(request, writer, http.StatusOK, map[string]any{ diff --git a/internal/web/v2_end_create_paste.go b/internal/web/v2_end_create_paste.go index b552300..e114043 100644 --- a/internal/web/v2_end_create_paste.go +++ b/internal/web/v2_end_create_paste.go @@ -16,9 +16,6 @@ type v2EndpointCreatePastePayload struct { func (server *Server) v2EndpointCreatePaste(writer http.ResponseWriter, request *http.Request) { // Read, parse and validate the request payload - if !accept(writer, request, "application/json") { - return - } body, err := io.ReadAll(request.Body) if err != nil { writeErr(request, writer, err) diff --git a/internal/web/v2_end_modify_paste.go b/internal/web/v2_end_modify_paste.go index 8a6d1b7..6c58882 100644 --- a/internal/web/v2_end_modify_paste.go +++ b/internal/web/v2_end_modify_paste.go @@ -20,9 +20,6 @@ func (server *Server) v2EndpointModifyPaste(writer http.ResponseWriter, request } // Read, parse and validate the request payload - if !accept(writer, request, "application/json") { - return - } body, err := io.ReadAll(request.Body) if err != nil { writeErr(request, writer, err) diff --git a/internal/web/v2_end_report_paste.go b/internal/web/v2_end_report_paste.go index 2fb0513..b48f1fc 100644 --- a/internal/web/v2_end_report_paste.go +++ b/internal/web/v2_end_report_paste.go @@ -20,9 +20,6 @@ func (server *Server) v2EndpointReportPaste(writer http.ResponseWriter, request } // Read, parse and validate the request payload - if !accept(writer, request, "application/json") { - return - } body, err := io.ReadAll(request.Body) if err != nil { writeErr(request, writer, err)