diff --git a/internal/web/request_accept.go b/internal/web/request_accept.go new file mode 100644 index 0000000..07ddf30 --- /dev/null +++ b/internal/web/request_accept.go @@ -0,0 +1,14 @@ +package web + +import "net/http" + +func accept(writer http.ResponseWriter, request *http.Request, contentTypes ...string) bool { + contentType := request.Header.Get("Content-Type") + for _, accepted := range contentTypes { + if contentType == accepted { + return true + } + } + writeString(writer, http.StatusUnsupportedMediaType, "unsupported media type") + return false +} diff --git a/internal/web/v2_end_create_paste.go b/internal/web/v2_end_create_paste.go index e114043..b552300 100644 --- a/internal/web/v2_end_create_paste.go +++ b/internal/web/v2_end_create_paste.go @@ -16,6 +16,9 @@ type v2EndpointCreatePastePayload struct { func (server *Server) v2EndpointCreatePaste(writer http.ResponseWriter, request *http.Request) { // Read, parse and validate the request payload + if !accept(writer, request, "application/json") { + return + } body, err := io.ReadAll(request.Body) if err != nil { writeErr(request, writer, err) diff --git a/internal/web/v2_end_modify_paste.go b/internal/web/v2_end_modify_paste.go index 6c58882..8a6d1b7 100644 --- a/internal/web/v2_end_modify_paste.go +++ b/internal/web/v2_end_modify_paste.go @@ -20,6 +20,9 @@ func (server *Server) v2EndpointModifyPaste(writer http.ResponseWriter, request } // Read, parse and validate the request payload + if !accept(writer, request, "application/json") { + return + } body, err := io.ReadAll(request.Body) if err != nil { writeErr(request, writer, err) diff --git a/internal/web/v2_end_report_paste.go b/internal/web/v2_end_report_paste.go index b48f1fc..2fb0513 100644 --- a/internal/web/v2_end_report_paste.go +++ b/internal/web/v2_end_report_paste.go @@ -20,6 +20,9 @@ func (server *Server) v2EndpointReportPaste(writer http.ResponseWriter, request } // Read, parse and validate the request payload + if !accept(writer, request, "application/json") { + return + } body, err := io.ReadAll(request.Body) if err != nil { writeErr(request, writer, err)