From d1dc73b5e6af8d3632b7c759f51aea50a9279bcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferdinand=20M=C3=BCtsch?= Date: Sat, 6 Feb 2021 20:09:08 +0100 Subject: [PATCH] refactor: make each router handler register middleware on its own --- main.go | 13 ++++---- middlewares/authenticate.go | 41 ++++++++++++++++---------- routes/api/heartbeat.go | 6 +++- routes/api/summary.go | 8 ++++- routes/compat/shields/v1/badge.go | 1 + routes/compat/wakatime/v1/all_time.go | 11 +++++-- routes/compat/wakatime/v1/stats.go | 11 +++++-- routes/compat/wakatime/v1/summaries.go | 11 +++++-- routes/settings.go | 2 +- routes/summary.go | 2 +- 10 files changed, 74 insertions(+), 32 deletions(-) diff --git a/main.go b/main.go index effd6b5..cb0d013 100644 --- a/main.go +++ b/main.go @@ -126,13 +126,13 @@ func main() { // API Handlers healthApiHandler := api.NewHealthApiHandler(db) - heartbeatApiHandler := api.NewHeartbeatApiHandler(heartbeatService, languageMappingService) - summaryApiHandler := api.NewSummaryApiHandler(summaryService) + heartbeatApiHandler := api.NewHeartbeatApiHandler(userService, heartbeatService, languageMappingService) + summaryApiHandler := api.NewSummaryApiHandler(userService, summaryService) // Compat Handlers - wakatimeV1AllHandler := wtV1Routes.NewAllTimeHandler(summaryService) - wakatimeV1SummariesHandler := wtV1Routes.NewSummariesHandler(summaryService) - wakatimeV1StatsHandler := wtV1Routes.NewStatsHandler(summaryService) + wakatimeV1AllHandler := wtV1Routes.NewAllTimeHandler(userService, summaryService) + wakatimeV1SummariesHandler := wtV1Routes.NewSummariesHandler(userService, summaryService) + wakatimeV1StatsHandler := wtV1Routes.NewStatsHandler(userService, summaryService) shieldV1BadgeHandler := shieldsV1Routes.NewBadgeHandler(summaryService, userService) // MVC Handlers @@ -152,11 +152,10 @@ func main() { recoveryMiddleware := handlers.RecoveryHandler() loggingMiddleware := middlewares.NewLoggingMiddleware(log.New(os.Stdout, "", log.LstdFlags)) corsMiddleware := handlers.CORS() - authenticateMiddleware := middlewares.NewAuthenticateMiddleware(userService, []string{"/api/health", "/api/compat/shields/v1"}).Handler // Router configs router.Use(loggingMiddleware, recoveryMiddleware) - apiRouter.Use(corsMiddleware, authenticateMiddleware) + apiRouter.Use(corsMiddleware) // Route registrations homeHandler.RegisterRoutes(rootRouter) diff --git a/middlewares/authenticate.go b/middlewares/authenticate.go index 4939e63..311db11 100644 --- a/middlewares/authenticate.go +++ b/middlewares/authenticate.go @@ -12,19 +12,23 @@ import ( ) type AuthenticateMiddleware struct { - config *conf.Config - userSrvc services.IUserService - whitelistPaths []string + config *conf.Config + userSrvc services.IUserService + optionalForPaths []string } -func NewAuthenticateMiddleware(userService services.IUserService, whitelistPaths []string) *AuthenticateMiddleware { +func NewAuthenticateMiddleware(userService services.IUserService) *AuthenticateMiddleware { return &AuthenticateMiddleware{ - config: conf.Get(), - userSrvc: userService, - whitelistPaths: whitelistPaths, + config: conf.Get(), + userSrvc: userService, + optionalForPaths: []string{}, } } +func (m *AuthenticateMiddleware) WithOptionalFor(paths []string) { + m.optionalForPaths = paths +} + func (m *AuthenticateMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.ServeHTTP(w, r, h.ServeHTTP) @@ -32,13 +36,6 @@ func (m *AuthenticateMiddleware) Handler(h http.Handler) http.Handler { } func (m *AuthenticateMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - for _, p := range m.whitelistPaths { - if strings.HasPrefix(r.URL.Path, p) || r.URL.Path == p { - next(w, r) - return - } - } - var user *models.User user, err := m.tryGetUserByCookie(r) @@ -46,7 +43,12 @@ func (m *AuthenticateMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques user, err = m.tryGetUserByApiKey(r) } - if err != nil { + if err != nil || user == nil { + if m.isOptional(r.URL.Path) { + next(w, r) + return + } + if strings.HasPrefix(r.URL.Path, "/api") { w.WriteHeader(http.StatusUnauthorized) } else { @@ -60,6 +62,15 @@ func (m *AuthenticateMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques next(w, r.WithContext(ctx)) } +func (m *AuthenticateMiddleware) isOptional(requestPath string) bool { + for _, p := range m.optionalForPaths { + if strings.HasPrefix(requestPath, p) || requestPath == p { + return true + } + } + return false +} + func (m *AuthenticateMiddleware) tryGetUserByApiKey(r *http.Request) (*models.User, error) { key, err := utils.ExtractBearerAuth(r) if err != nil { diff --git a/routes/api/heartbeat.go b/routes/api/heartbeat.go index f44088b..56eaa44 100644 --- a/routes/api/heartbeat.go +++ b/routes/api/heartbeat.go @@ -5,6 +5,7 @@ import ( "github.com/emvi/logbuch" "github.com/gorilla/mux" conf "github.com/muety/wakapi/config" + "github.com/muety/wakapi/middlewares" customMiddleware "github.com/muety/wakapi/middlewares/custom" "github.com/muety/wakapi/services" "github.com/muety/wakapi/utils" @@ -15,13 +16,15 @@ import ( type HeartbeatApiHandler struct { config *conf.Config + userSrvc services.IUserService heartbeatSrvc services.IHeartbeatService languageMappingSrvc services.ILanguageMappingService } -func NewHeartbeatApiHandler(heartbeatService services.IHeartbeatService, languageMappingService services.ILanguageMappingService) *HeartbeatApiHandler { +func NewHeartbeatApiHandler(userService services.IUserService, heartbeatService services.IHeartbeatService, languageMappingService services.ILanguageMappingService) *HeartbeatApiHandler { return &HeartbeatApiHandler{ config: conf.Get(), + userSrvc: userService, heartbeatSrvc: heartbeatService, languageMappingSrvc: languageMappingService, } @@ -34,6 +37,7 @@ type heartbeatResponseVm struct { func (h *HeartbeatApiHandler) RegisterRoutes(router *mux.Router) { r := router.PathPrefix("/heartbeat").Subrouter() r.Use( + middlewares.NewAuthenticateMiddleware(h.userSrvc).Handler, customMiddleware.NewWakatimeRelayMiddleware().Handler, ) r.Methods(http.MethodPost).HandlerFunc(h.Post) diff --git a/routes/api/summary.go b/routes/api/summary.go index 8b39970..0e3c927 100644 --- a/routes/api/summary.go +++ b/routes/api/summary.go @@ -3,6 +3,7 @@ package api import ( "github.com/gorilla/mux" conf "github.com/muety/wakapi/config" + "github.com/muety/wakapi/middlewares" su "github.com/muety/wakapi/routes/utils" "github.com/muety/wakapi/services" "github.com/muety/wakapi/utils" @@ -11,18 +12,23 @@ import ( type SummaryApiHandler struct { config *conf.Config + userSrvc services.IUserService summarySrvc services.ISummaryService } -func NewSummaryApiHandler(summaryService services.ISummaryService) *SummaryApiHandler { +func NewSummaryApiHandler(userService services.IUserService, summaryService services.ISummaryService) *SummaryApiHandler { return &SummaryApiHandler{ summarySrvc: summaryService, + userSrvc: userService, config: conf.Get(), } } func (h *SummaryApiHandler) RegisterRoutes(router *mux.Router) { r := router.PathPrefix("/summary").Subrouter() + r.Use( + middlewares.NewAuthenticateMiddleware(h.userSrvc).Handler, + ) r.Methods(http.MethodGet).HandlerFunc(h.Get) } diff --git a/routes/compat/shields/v1/badge.go b/routes/compat/shields/v1/badge.go index 16431b4..6e43ab2 100644 --- a/routes/compat/shields/v1/badge.go +++ b/routes/compat/shields/v1/badge.go @@ -32,6 +32,7 @@ func NewBadgeHandler(summaryService services.ISummaryService, userService servic } func (h *BadgeHandler) RegisterRoutes(router *mux.Router) { + // no auth middleware here, handler itself resolves the user r := router.PathPrefix("/shields/v1/{user}").Subrouter() r.Methods(http.MethodGet).HandlerFunc(h.Get) } diff --git a/routes/compat/wakatime/v1/all_time.go b/routes/compat/wakatime/v1/all_time.go index d78e68e..89b605d 100644 --- a/routes/compat/wakatime/v1/all_time.go +++ b/routes/compat/wakatime/v1/all_time.go @@ -3,6 +3,7 @@ package v1 import ( "github.com/gorilla/mux" conf "github.com/muety/wakapi/config" + "github.com/muety/wakapi/middlewares" "github.com/muety/wakapi/models" v1 "github.com/muety/wakapi/models/compat/wakatime/v1" "github.com/muety/wakapi/services" @@ -14,18 +15,24 @@ import ( type AllTimeHandler struct { config *conf.Config + userSrvc services.IUserService summarySrvc services.ISummaryService } -func NewAllTimeHandler(summaryService services.ISummaryService) *AllTimeHandler { +func NewAllTimeHandler(userService services.IUserService, summaryService services.ISummaryService) *AllTimeHandler { return &AllTimeHandler{ + userSrvc: userService, summarySrvc: summaryService, config: conf.Get(), } } func (h *AllTimeHandler) RegisterRoutes(router *mux.Router) { - router.Path("/wakatime/v1/users/{user}/all_time_since_today").Methods(http.MethodGet).HandlerFunc(h.Get) + r := router.PathPrefix("/wakatime/v1/users/{user}/all_time_since_today").Subrouter() + r.Use( + middlewares.NewAuthenticateMiddleware(h.userSrvc).Handler, + ) + r.Methods(http.MethodGet).HandlerFunc(h.Get) } func (h *AllTimeHandler) Get(w http.ResponseWriter, r *http.Request) { diff --git a/routes/compat/wakatime/v1/stats.go b/routes/compat/wakatime/v1/stats.go index 0d38875..02925cb 100644 --- a/routes/compat/wakatime/v1/stats.go +++ b/routes/compat/wakatime/v1/stats.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/gorilla/mux" conf "github.com/muety/wakapi/config" + "github.com/muety/wakapi/middlewares" "github.com/muety/wakapi/models" v1 "github.com/muety/wakapi/models/compat/wakatime/v1" "github.com/muety/wakapi/services" @@ -14,18 +15,24 @@ import ( type StatsHandler struct { config *conf.Config + userSrvc services.IUserService summarySrvc services.ISummaryService } -func NewStatsHandler(summaryService services.ISummaryService) *StatsHandler { +func NewStatsHandler(userService services.IUserService, summaryService services.ISummaryService) *StatsHandler { return &StatsHandler{ + userSrvc: userService, summarySrvc: summaryService, config: conf.Get(), } } func (h *StatsHandler) RegisterRoutes(router *mux.Router) { - router.Path("/wakatime/v1/users/{user}/stats/{range}").Methods(http.MethodGet).HandlerFunc(h.Get) + r := router.PathPrefix("/wakatime/v1/users/{user}/stats/{range}").Subrouter() + r.Use( + middlewares.NewAuthenticateMiddleware(h.userSrvc).Handler, + ) + r.Methods(http.MethodGet).HandlerFunc(h.Get) } // TODO: support filtering (requires https://github.com/muety/wakapi/issues/108) diff --git a/routes/compat/wakatime/v1/summaries.go b/routes/compat/wakatime/v1/summaries.go index 34bc800..5a331f9 100644 --- a/routes/compat/wakatime/v1/summaries.go +++ b/routes/compat/wakatime/v1/summaries.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/gorilla/mux" conf "github.com/muety/wakapi/config" + "github.com/muety/wakapi/middlewares" "github.com/muety/wakapi/models" v1 "github.com/muety/wakapi/models/compat/wakatime/v1" "github.com/muety/wakapi/services" @@ -15,18 +16,24 @@ import ( type SummariesHandler struct { config *conf.Config + userSrvc services.IUserService summarySrvc services.ISummaryService } -func NewSummariesHandler(summaryService services.ISummaryService) *SummariesHandler { +func NewSummariesHandler(userService services.IUserService, summaryService services.ISummaryService) *SummariesHandler { return &SummariesHandler{ + userSrvc: userService, summarySrvc: summaryService, config: conf.Get(), } } func (h *SummariesHandler) RegisterRoutes(router *mux.Router) { - router.Path("/wakatime/v1/users/{user}/summaries").Methods(http.MethodGet).HandlerFunc(h.Get) + r := router.PathPrefix("/wakatime/v1/users/{user}/summaries").Subrouter() + r.Use( + middlewares.NewAuthenticateMiddleware(h.userSrvc).Handler, + ) + r.Methods(http.MethodGet).HandlerFunc(h.Get) } // TODO: Support parameters: project, branches, timeout, writes_only, timezone diff --git a/routes/settings.go b/routes/settings.go index 2de69c8..b3a6f2a 100644 --- a/routes/settings.go +++ b/routes/settings.go @@ -57,7 +57,7 @@ func NewSettingsHandler( func (h *SettingsHandler) RegisterRoutes(router *mux.Router) { r := router.PathPrefix("/settings").Subrouter() r.Use( - middlewares.NewAuthenticateMiddleware(h.userSrvc, []string{}).Handler, + middlewares.NewAuthenticateMiddleware(h.userSrvc).Handler, ) r.Methods(http.MethodGet).HandlerFunc(h.GetIndex) r.Methods(http.MethodPost).HandlerFunc(h.PostIndex) diff --git a/routes/summary.go b/routes/summary.go index 42ec1f3..fcf8dd5 100644 --- a/routes/summary.go +++ b/routes/summary.go @@ -29,7 +29,7 @@ func NewSummaryHandler(summaryService services.ISummaryService, userService serv func (h *SummaryHandler) RegisterRoutes(router *mux.Router) { r := router.PathPrefix("/summary").Subrouter() r.Use( - middlewares.NewAuthenticateMiddleware(h.userSrvc, []string{}).Handler, + middlewares.NewAuthenticateMiddleware(h.userSrvc).Handler, ) r.Methods(http.MethodGet).HandlerFunc(h.GetIndex) }