diff --git a/config/config.go b/config/config.go index 07a78b6..b4cd241 100644 --- a/config/config.go +++ b/config/config.go @@ -34,6 +34,8 @@ const ( KeySubscriptionNotificationSent = "sub_reminder" KeyNewsbox = "newsbox" + SessionKeyDefault = "default" + SimpleDateFormat = "2006-01-02" SimpleDateTimeFormat = "2006-01-02 15:04:05" @@ -92,6 +94,7 @@ type securityConfig struct { InsecureCookies bool `yaml:"insecure_cookies" default:"false" env:"WAKAPI_INSECURE_COOKIES"` CookieMaxAgeSec int `yaml:"cookie_max_age" default:"172800" env:"WAKAPI_COOKIE_MAX_AGE"` SecureCookie *securecookie.SecureCookie `yaml:"-"` + SessionKey []byte `yaml:"-"` } type dbConfig struct { @@ -394,6 +397,7 @@ func Load(version string) *Config { securecookie.GenerateRandomKey(64), securecookie.GenerateRandomKey(32), ) + config.Security.SessionKey = securecookie.GenerateRandomKey(32) if strings.HasSuffix(config.Server.BasePath, "/") { config.Server.BasePath = config.Server.BasePath[:len(config.Server.BasePath)-1] diff --git a/config/session.go b/config/session.go new file mode 100644 index 0000000..d721ad7 --- /dev/null +++ b/config/session.go @@ -0,0 +1,14 @@ +package config + +import "github.com/gorilla/sessions" + +// sessions are only used for displaying flash messages + +var sessionStore *sessions.CookieStore + +func GetSessionStore() *sessions.CookieStore { + if sessionStore == nil { + sessionStore = sessions.NewCookieStore(Get().Security.SessionKey) + } + return sessionStore +} diff --git a/go.mod b/go.mod index e142de1..7517934 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/google/uuid v1.3.0 // indirect + github.com/gorilla/sessions v1.2.1 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.13.0 // indirect github.com/jackc/pgio v1.0.0 // indirect diff --git a/go.sum b/go.sum index 2b63822..87081cf 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,8 @@ github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= diff --git a/middlewares/authenticate.go b/middlewares/authenticate.go index 06bc41b..a172cda 100644 --- a/middlewares/authenticate.go +++ b/middlewares/authenticate.go @@ -22,10 +22,11 @@ var ( ) type AuthenticateMiddleware struct { - config *conf.Config - userSrvc services.IUserService - optionalForPaths []string - redirectTarget string // optional + config *conf.Config + userSrvc services.IUserService + optionalForPaths []string + redirectTarget string // optional + redirectErrorMessage string // optional } func NewAuthenticateMiddleware(userService services.IUserService) *AuthenticateMiddleware { @@ -46,6 +47,11 @@ func (m *AuthenticateMiddleware) WithRedirectTarget(path string) *AuthenticateMi return m } +func (m *AuthenticateMiddleware) WithRedirectErrorMessage(message string) *AuthenticateMiddleware { + m.redirectErrorMessage = message + return m +} + func (m *AuthenticateMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.ServeHTTP(w, r, h.ServeHTTP) @@ -73,6 +79,11 @@ func (m *AuthenticateMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(conf.ErrUnauthorized)) } else { + if m.redirectErrorMessage != "" { + session, _ := conf.GetSessionStore().Get(r, conf.SessionKeyDefault) + session.AddFlash(m.redirectErrorMessage, "error") + session.Save(r, w) + } http.SetCookie(w, m.config.GetClearCookie(models.AuthCookieKey)) http.Redirect(w, r, m.redirectTarget, http.StatusFound) } diff --git a/models/view/common.go b/models/view/common.go new file mode 100644 index 0000000..af2269d --- /dev/null +++ b/models/view/common.go @@ -0,0 +1,19 @@ +package view + +type BasicViewModel interface { + SetError(string) + SetSuccess(string) +} + +type Messages struct { + Success string + Error string +} + +func (m *Messages) SetError(message string) { + m.Error = message +} + +func (m *Messages) SetSuccess(message string) { + m.Success = message +} diff --git a/models/view/home.go b/models/view/home.go index 20dda58..c991c1f 100644 --- a/models/view/home.go +++ b/models/view/home.go @@ -6,19 +6,18 @@ type Newsbox struct { } type HomeViewModel struct { - Success string - Error string + Messages TotalHours int TotalUsers int Newsbox *Newsbox } func (s *HomeViewModel) WithSuccess(m string) *HomeViewModel { - s.Success = m + s.SetSuccess(m) return s } func (s *HomeViewModel) WithError(m string) *HomeViewModel { - s.Error = m + s.SetError(m) return s } diff --git a/models/view/imprint.go b/models/view/imprint.go index c709910..fc7cf5b 100644 --- a/models/view/imprint.go +++ b/models/view/imprint.go @@ -1,18 +1,17 @@ package view type ImprintViewModel struct { + Messages HtmlText string - Success string - Error string } func (s *ImprintViewModel) WithSuccess(m string) *ImprintViewModel { - s.Success = m + s.SetSuccess(m) return s } func (s *ImprintViewModel) WithError(m string) *ImprintViewModel { - s.Error = m + s.SetError(m) return s } diff --git a/models/view/leaderboard.go b/models/view/leaderboard.go index 5ef5554..1eaeff7 100644 --- a/models/view/leaderboard.go +++ b/models/view/leaderboard.go @@ -8,6 +8,7 @@ import ( ) type LeaderboardViewModel struct { + Messages User *models.User By string Key string @@ -16,17 +17,15 @@ type LeaderboardViewModel struct { UserLanguages map[string][]string ApiKey string PageParams *utils.PageParams - Success string - Error string } func (s *LeaderboardViewModel) WithSuccess(m string) *LeaderboardViewModel { - s.Success = m + s.SetSuccess(m) return s } func (s *LeaderboardViewModel) WithError(m string) *LeaderboardViewModel { - s.Error = m + s.SetError(m) return s } diff --git a/models/view/login.go b/models/view/login.go index c7645d4..2ca5c64 100644 --- a/models/view/login.go +++ b/models/view/login.go @@ -1,8 +1,7 @@ package view type LoginViewModel struct { - Success string - Error string + Messages TotalUsers int AllowSignup bool } @@ -13,11 +12,11 @@ type SetPasswordViewModel struct { } func (s *LoginViewModel) WithSuccess(m string) *LoginViewModel { - s.Success = m + s.SetSuccess(m) return s } func (s *LoginViewModel) WithError(m string) *LoginViewModel { - s.Error = m + s.SetError(m) return s } diff --git a/models/view/settings.go b/models/view/settings.go index 3388f4c..8d607a2 100644 --- a/models/view/settings.go +++ b/models/view/settings.go @@ -6,6 +6,7 @@ import ( ) type SettingsViewModel struct { + Messages User *models.User LanguageMappings []*models.LanguageMapping Aliases []*SettingsVMCombinedAlias @@ -16,8 +17,6 @@ type SettingsViewModel struct { UserFirstData time.Time SupportContact string ApiKey string - Success string - Error string } type SettingsVMCombinedAlias struct { @@ -36,11 +35,11 @@ func (s *SettingsViewModel) SubscriptionsEnabled() bool { } func (s *SettingsViewModel) WithSuccess(m string) *SettingsViewModel { - s.Success = m + s.SetSuccess(m) return s } func (s *SettingsViewModel) WithError(m string) *SettingsViewModel { - s.Error = m + s.SetError(m) return s } diff --git a/models/view/summary.go b/models/view/summary.go index 04cd001..1bea3b4 100644 --- a/models/view/summary.go +++ b/models/view/summary.go @@ -3,6 +3,7 @@ package view import "github.com/muety/wakapi/models" type SummaryViewModel struct { + Messages *models.Summary *models.SummaryParams User *models.User @@ -10,18 +11,16 @@ type SummaryViewModel struct { EditorColors map[string]string LanguageColors map[string]string OSColors map[string]string - Error string - Success string ApiKey string RawQuery string } func (s *SummaryViewModel) WithSuccess(m string) *SummaryViewModel { - s.Success = m + s.SetSuccess(m) return s } func (s *SummaryViewModel) WithError(m string) *SummaryViewModel { - s.Error = m + s.SetError(m) return s } diff --git a/routes/home.go b/routes/home.go index e41904d..4c7f866 100644 --- a/routes/home.go +++ b/routes/home.go @@ -9,6 +9,7 @@ import ( conf "github.com/muety/wakapi/config" "github.com/muety/wakapi/models" "github.com/muety/wakapi/models/view" + routeutils "github.com/muety/wakapi/routes/utils" "github.com/muety/wakapi/services" "net/http" "strconv" @@ -46,10 +47,10 @@ func (h *HomeHandler) GetIndex(w http.ResponseWriter, r *http.Request) { return } - templates[conf.IndexTemplate].Execute(w, h.buildViewModel(r)) + templates[conf.IndexTemplate].Execute(w, h.buildViewModel(r, w)) } -func (h *HomeHandler) buildViewModel(r *http.Request) *view.HomeViewModel { +func (h *HomeHandler) buildViewModel(r *http.Request, w http.ResponseWriter) *view.HomeViewModel { var totalHours int var totalUsers int var newsbox view.Newsbox @@ -72,11 +73,10 @@ func (h *HomeHandler) buildViewModel(r *http.Request) *view.HomeViewModel { } } - return &view.HomeViewModel{ - Success: r.URL.Query().Get("success"), - Error: r.URL.Query().Get("error"), + vm := &view.HomeViewModel{ TotalHours: totalHours, TotalUsers: totalUsers, Newsbox: &newsbox, } + return routeutils.WithSessionMessages(vm, r, w) } diff --git a/routes/imprint.go b/routes/imprint.go index 3d7a9ee..7ce5e51 100644 --- a/routes/imprint.go +++ b/routes/imprint.go @@ -39,8 +39,5 @@ func (h *ImprintHandler) GetImprint(w http.ResponseWriter, r *http.Request) { } func (h *ImprintHandler) buildViewModel(r *http.Request) *view.ImprintViewModel { - return &view.ImprintViewModel{ - Success: r.URL.Query().Get("success"), - Error: r.URL.Query().Get("error"), - } + return &view.ImprintViewModel{} } diff --git a/routes/leaderboard.go b/routes/leaderboard.go index 223c838..fb43f2c 100644 --- a/routes/leaderboard.go +++ b/routes/leaderboard.go @@ -9,6 +9,7 @@ import ( "github.com/muety/wakapi/middlewares" "github.com/muety/wakapi/models" "github.com/muety/wakapi/models/view" + routeutils "github.com/muety/wakapi/routes/utils" "github.com/muety/wakapi/services" "github.com/muety/wakapi/utils" "net/http" @@ -38,6 +39,7 @@ func (h *LeaderboardHandler) RegisterRoutes(router *mux.Router) { r.Use( middlewares.NewAuthenticateMiddleware(h.userService). WithRedirectTarget(defaultErrorRedirectTarget()). + WithRedirectErrorMessage("unauthorized"). WithOptionalFor([]string{"/"}). Handler, ) @@ -48,12 +50,12 @@ func (h *LeaderboardHandler) GetIndex(w http.ResponseWriter, r *http.Request) { if h.config.IsDev() { loadTemplates() } - if err := templates[conf.LeaderboardTemplate].Execute(w, h.buildViewModel(r)); err != nil { + if err := templates[conf.LeaderboardTemplate].Execute(w, h.buildViewModel(r, w)); err != nil { logbuch.Error(err.Error()) } } -func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardViewModel { +func (h *LeaderboardHandler) buildViewModel(r *http.Request, w http.ResponseWriter) *view.LeaderboardViewModel { user := middlewares.GetPrincipal(r) byParam := strings.ToLower(r.URL.Query().Get("by")) keyParam := strings.ToLower(r.URL.Query().Get("key")) @@ -71,7 +73,9 @@ func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardVi leaderboard, err = h.leaderboardService.GetByInterval(models.IntervalPast7Days, pageParams, true) if err != nil { conf.Log().Request(r).Error("error while fetching general leaderboard items - %v", err) - return &view.LeaderboardViewModel{Error: criticalError} + return &view.LeaderboardViewModel{ + Messages: view.Messages{Error: criticalError}, + } } // regardless of page, always show own rank @@ -88,7 +92,9 @@ func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardVi leaderboard, err = h.leaderboardService.GetAggregatedByInterval(models.IntervalPast7Days, &by, pageParams, true) if err != nil { conf.Log().Request(r).Error("error while fetching general leaderboard items - %v", err) - return &view.LeaderboardViewModel{Error: criticalError} + return &view.LeaderboardViewModel{ + Messages: view.Messages{Error: criticalError}, + } } // regardless of page, always show own rank @@ -120,7 +126,9 @@ func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardVi leaderboard = leaderboard.TopByKey(by, keyParam) } } else { - return &view.LeaderboardViewModel{Error: fmt.Sprintf("unsupported aggregation '%s'", byParam)} + return &view.LeaderboardViewModel{ + Messages: view.Messages{Error: fmt.Sprintf("unsupported aggregation '%s'", byParam)}, + } } } @@ -129,7 +137,7 @@ func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardVi apiKey = user.ApiKey } - return &view.LeaderboardViewModel{ + vm := &view.LeaderboardViewModel{ User: user, By: byParam, Key: keyParam, @@ -138,7 +146,6 @@ func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardVi TopKeys: topKeys, ApiKey: apiKey, PageParams: pageParams, - Success: r.URL.Query().Get("success"), - Error: r.URL.Query().Get("error"), } + return routeutils.WithSessionMessages(vm, r, w) } diff --git a/routes/login.go b/routes/login.go index f32ca63..4b4720b 100644 --- a/routes/login.go +++ b/routes/login.go @@ -8,6 +8,7 @@ import ( "github.com/muety/wakapi/middlewares" "github.com/muety/wakapi/models" "github.com/muety/wakapi/models/view" + routeutils "github.com/muety/wakapi/routes/utils" "github.com/muety/wakapi/services" "github.com/muety/wakapi/utils" "net/http" @@ -41,6 +42,7 @@ func (h *LoginHandler) RegisterRoutes(router *mux.Router) { authMiddleware := middlewares.NewAuthenticateMiddleware(h.userSrvc). WithRedirectTarget(defaultErrorRedirectTarget()). + WithRedirectErrorMessage("unauthorized"). WithOptionalFor([]string{"/logout"}) logoutRouter := router.PathPrefix("/logout").Subrouter() @@ -58,7 +60,7 @@ func (h *LoginHandler) GetIndex(w http.ResponseWriter, r *http.Request) { return } - templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r)) + templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w)) } func (h *LoginHandler) PostLogin(w http.ResponseWriter, r *http.Request) { @@ -74,25 +76,25 @@ func (h *LoginHandler) PostLogin(w http.ResponseWriter, r *http.Request) { var login models.Login if err := r.ParseForm(); err != nil { w.WriteHeader(http.StatusBadRequest) - templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters")) + templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters")) return } if err := loginDecoder.Decode(&login, r.PostForm); err != nil { w.WriteHeader(http.StatusBadRequest) - templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters")) + templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters")) return } user, err := h.userSrvc.GetUserById(login.Username) if err != nil { w.WriteHeader(http.StatusNotFound) - templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r).WithError("resource not found")) + templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w).WithError("resource not found")) return } if !utils.CompareBcrypt(user.Password, login.Password, h.config.Security.PasswordSalt) { w.WriteHeader(http.StatusUnauthorized) - templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r).WithError("invalid credentials")) + templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w).WithError("invalid credentials")) return } @@ -100,7 +102,7 @@ func (h *LoginHandler) PostLogin(w http.ResponseWriter, r *http.Request) { if err != nil { w.WriteHeader(http.StatusInternalServerError) conf.Log().Request(r).Error("failed to encode secure cookie - %v", err) - templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r).WithError("internal server error")) + templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w).WithError("internal server error")) return } @@ -133,7 +135,7 @@ func (h *LoginHandler) GetSignup(w http.ResponseWriter, r *http.Request) { return } - templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r)) + templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w)) } func (h *LoginHandler) PostSignup(w http.ResponseWriter, r *http.Request) { @@ -143,7 +145,7 @@ func (h *LoginHandler) PostSignup(w http.ResponseWriter, r *http.Request) { if !h.config.IsDev() && !h.config.Security.AllowSignup { w.WriteHeader(http.StatusForbidden) - templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("registration is disabled on this server")) + templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("registration is disabled on this server")) return } @@ -155,18 +157,18 @@ func (h *LoginHandler) PostSignup(w http.ResponseWriter, r *http.Request) { var signup models.Signup if err := r.ParseForm(); err != nil { w.WriteHeader(http.StatusBadRequest) - templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters")) + templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters")) return } if err := signupDecoder.Decode(&signup, r.PostForm); err != nil { w.WriteHeader(http.StatusBadRequest) - templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters")) + templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters")) return } if !signup.IsValid() { w.WriteHeader(http.StatusBadRequest) - templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("invalid parameters")) + templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("invalid parameters")) return } @@ -176,23 +178,24 @@ func (h *LoginHandler) PostSignup(w http.ResponseWriter, r *http.Request) { if err != nil { w.WriteHeader(http.StatusInternalServerError) conf.Log().Request(r).Error("failed to create new user - %v", err) - templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("failed to create new user")) + templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("failed to create new user")) return } if !created { w.WriteHeader(http.StatusConflict) - templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("user already existing")) + templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("user already existing")) return } - http.Redirect(w, r, fmt.Sprintf("%s/?success=%s", h.config.Server.BasePath, "account created successfully"), http.StatusFound) + routeutils.SetSuccess(r, w, "account created successfully") + http.Redirect(w, r, h.config.Server.BasePath, http.StatusFound) } func (h *LoginHandler) GetResetPassword(w http.ResponseWriter, r *http.Request) { if h.config.IsDev() { loadTemplates() } - templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r)) + templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r, w)) } func (h *LoginHandler) GetSetPassword(w http.ResponseWriter, r *http.Request) { @@ -204,12 +207,12 @@ func (h *LoginHandler) GetSetPassword(w http.ResponseWriter, r *http.Request) { token := values.Get("token") if token == "" { w.WriteHeader(http.StatusUnauthorized) - templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("invalid or missing token")) + templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("invalid or missing token")) return } vm := &view.SetPasswordViewModel{ - LoginViewModel: *h.buildViewModel(r), + LoginViewModel: *h.buildViewModel(r, w), Token: token, } @@ -224,25 +227,25 @@ func (h *LoginHandler) PostSetPassword(w http.ResponseWriter, r *http.Request) { var setRequest models.SetPasswordRequest if err := r.ParseForm(); err != nil { w.WriteHeader(http.StatusBadRequest) - templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters")) + templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters")) return } if err := signupDecoder.Decode(&setRequest, r.PostForm); err != nil { w.WriteHeader(http.StatusBadRequest) - templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters")) + templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters")) return } user, err := h.userSrvc.GetUserByResetToken(setRequest.Token) if err != nil { w.WriteHeader(http.StatusUnauthorized) - templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("invalid token")) + templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("invalid token")) return } if !setRequest.IsValid() { w.WriteHeader(http.StatusBadRequest) - templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("invalid parameters")) + templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("invalid parameters")) return } @@ -251,7 +254,7 @@ func (h *LoginHandler) PostSetPassword(w http.ResponseWriter, r *http.Request) { if hash, err := utils.HashBcrypt(user.Password, h.config.Security.PasswordSalt); err != nil { w.WriteHeader(http.StatusInternalServerError) conf.Log().Request(r).Error("failed to set new password - %v", err) - templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("failed to set new password")) + templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("failed to set new password")) return } else { user.Password = hash @@ -260,11 +263,12 @@ func (h *LoginHandler) PostSetPassword(w http.ResponseWriter, r *http.Request) { if _, err := h.userSrvc.Update(user); err != nil { w.WriteHeader(http.StatusInternalServerError) conf.Log().Request(r).Error("failed to save new password - %v", err) - templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("failed to save new password")) + templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("failed to save new password")) return } - http.Redirect(w, r, fmt.Sprintf("%s/login?success=%s", h.config.Server.BasePath, "password updated successfully"), http.StatusFound) + routeutils.SetSuccess(r, w, "password updated successfully") + http.Redirect(w, r, fmt.Sprintf("%s/login", h.config.Server.BasePath), http.StatusFound) } func (h *LoginHandler) PostResetPassword(w http.ResponseWriter, r *http.Request) { @@ -274,19 +278,19 @@ func (h *LoginHandler) PostResetPassword(w http.ResponseWriter, r *http.Request) if !h.config.Mail.Enabled { w.WriteHeader(http.StatusNotImplemented) - templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("mailing is disabled on this server")) + templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("mailing is disabled on this server")) return } var resetRequest models.ResetPasswordRequest if err := r.ParseForm(); err != nil { w.WriteHeader(http.StatusBadRequest) - templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters")) + templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters")) return } if err := resetPasswordDecoder.Decode(&resetRequest, r.PostForm); err != nil { w.WriteHeader(http.StatusBadRequest) - templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters")) + templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters")) return } @@ -294,7 +298,7 @@ func (h *LoginHandler) PostResetPassword(w http.ResponseWriter, r *http.Request) if u, err := h.userSrvc.GenerateResetToken(user); err != nil { w.WriteHeader(http.StatusInternalServerError) conf.Log().Request(r).Error("failed to generate password reset token - %v", err) - templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("failed to generate password reset token")) + templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("failed to generate password reset token")) return } else { go func(user *models.User) { @@ -310,16 +314,16 @@ func (h *LoginHandler) PostResetPassword(w http.ResponseWriter, r *http.Request) conf.Log().Request(r).Warn("password reset requested for unregistered address '%s'", resetRequest.Email) } - http.Redirect(w, r, fmt.Sprintf("%s/?success=%s", h.config.Server.BasePath, "an e-mail was sent to you in case your e-mail address was registered"), http.StatusFound) + routeutils.SetSuccess(r, w, "an e-mail was sent to you in case your e-mail address was registered") + http.Redirect(w, r, h.config.Server.BasePath, http.StatusFound) } -func (h *LoginHandler) buildViewModel(r *http.Request) *view.LoginViewModel { +func (h *LoginHandler) buildViewModel(r *http.Request, w http.ResponseWriter) *view.LoginViewModel { numUsers, _ := h.userSrvc.Count() - return &view.LoginViewModel{ - Success: r.URL.Query().Get("success"), - Error: r.URL.Query().Get("error"), + vm := &view.LoginViewModel{ TotalUsers: int(numUsers), AllowSignup: h.config.IsDev() || h.config.Security.AllowSignup, } + return routeutils.WithSessionMessages(vm, r, w) } diff --git a/routes/routes.go b/routes/routes.go index 7b281b0..a060786 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -1,7 +1,6 @@ package routes import ( - "fmt" "github.com/muety/wakapi/helpers" "html/template" "net/http" @@ -105,7 +104,7 @@ func loadTemplates() { } func defaultErrorRedirectTarget() string { - return fmt.Sprintf("%s/?error=unauthorized", config.Get().Server.BasePath) + return config.Get().Server.BasePath + "/" } func add(i, j int) int { diff --git a/routes/settings.go b/routes/settings.go index d1a0448..3ef7cd8 100644 --- a/routes/settings.go +++ b/routes/settings.go @@ -69,7 +69,10 @@ func NewSettingsHandler( func (h *SettingsHandler) RegisterRoutes(router *mux.Router) { r := router.PathPrefix("/settings").Subrouter() r.Use( - middlewares.NewAuthenticateMiddleware(h.userSrvc).WithRedirectTarget(defaultErrorRedirectTarget()).Handler, + middlewares.NewAuthenticateMiddleware(h.userSrvc). + WithRedirectTarget(defaultErrorRedirectTarget()). + WithRedirectErrorMessage("unauthorized"). + Handler, ) r.Methods(http.MethodGet).HandlerFunc(h.GetIndex) r.Methods(http.MethodPost).HandlerFunc(h.PostIndex) @@ -79,7 +82,7 @@ func (h *SettingsHandler) GetIndex(w http.ResponseWriter, r *http.Request) { if h.config.IsDev() { loadTemplates() } - templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r)) + templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w)) } func (h *SettingsHandler) PostIndex(w http.ResponseWriter, r *http.Request) { @@ -89,7 +92,7 @@ func (h *SettingsHandler) PostIndex(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { w.WriteHeader(http.StatusBadRequest) - templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r).WithError("missing form values")) + templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing form values")) return } @@ -100,7 +103,7 @@ func (h *SettingsHandler) PostIndex(w http.ResponseWriter, r *http.Request) { if actionFunc == nil { logbuch.Warn("failed to dispatch action '%s'", action) w.WriteHeader(http.StatusBadRequest) - templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r).WithError("unknown action requests")) + templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w).WithError("unknown action requests")) return } @@ -113,15 +116,15 @@ func (h *SettingsHandler) PostIndex(w http.ResponseWriter, r *http.Request) { if errorMsg != "" { w.WriteHeader(status) - templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r).WithError(errorMsg)) + templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w).WithError(errorMsg)) return } if successMsg != "" { w.WriteHeader(status) - templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r).WithSuccess(successMsg)) + templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w).WithSuccess(successMsg)) return } - templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r)) + templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w)) } func (h *SettingsHandler) dispatchAction(action string) action { @@ -622,8 +625,9 @@ func (h *SettingsHandler) actionDeleteUser(w http.ResponseWriter, r *http.Reques } }(user) + routeutils.SetSuccess(r, w, "Your account will be deleted in a few minutes. Sorry to you go.") http.SetCookie(w, h.config.GetClearCookie(models.AuthCookieKey)) - http.Redirect(w, r, fmt.Sprintf("%s/?success=%s", h.config.Server.BasePath, "Your account will be deleted in a few minutes. Sorry to you go."), http.StatusFound) + http.Redirect(w, r, h.config.Server.BasePath, http.StatusFound) return -1, "", "" } @@ -673,7 +677,7 @@ func (h *SettingsHandler) regenerateSummaries(user *models.User) error { return nil } -func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewModel { +func (h *SettingsHandler) buildViewModel(r *http.Request, w http.ResponseWriter) *view.SettingsViewModel { user := middlewares.GetPrincipal(r) // mappings @@ -683,7 +687,7 @@ func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewMode aliases, err := h.aliasSrvc.GetByUser(user.ID) if err != nil { conf.Log().Request(r).Error("error while building alias map - %v", err) - return &view.SettingsViewModel{Error: criticalError} + return &view.SettingsViewModel{Messages: view.Messages{Error: criticalError}} } aliasMap := make(map[string][]*models.Alias) for _, a := range aliases { @@ -712,7 +716,7 @@ func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewMode labelMap, err := h.projectLabelSrvc.GetByUserGroupedInverted(user.ID) if err != nil { conf.Log().Request(r).Error("error while building settings project label map - %v", err) - return &view.SettingsViewModel{Error: criticalError} + return &view.SettingsViewModel{Messages: view.Messages{Error: criticalError}} } combinedLabels := make([]*view.SettingsVMCombinedLabel, 0) @@ -734,7 +738,7 @@ func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewMode projects, err := routeutils.GetEffectiveProjectsList(user, h.heartbeatSrvc, h.aliasSrvc) if err != nil { conf.Log().Request(r).Error("error while fetching projects - %v", err) - return &view.SettingsViewModel{Error: criticalError} + return &view.SettingsViewModel{Messages: view.Messages{Error: criticalError}} } // subscriptions @@ -750,7 +754,7 @@ func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewMode firstData, _ = time.Parse(time.RFC822Z, firstDataKv.Value) } - return &view.SettingsViewModel{ + vm := &view.SettingsViewModel{ User: user, LanguageMappings: mappings, Aliases: combinedAliases, @@ -761,7 +765,6 @@ func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewMode SubscriptionPrice: subscriptionPrice, SupportContact: h.config.App.SupportContact, DataRetentionMonths: h.config.App.DataRetentionMonths, - Success: r.URL.Query().Get("success"), - Error: r.URL.Query().Get("error"), } + return routeutils.WithSessionMessages(vm, r, w) } diff --git a/routes/subscription.go b/routes/subscription.go index 574598d..5aa4fc4 100644 --- a/routes/subscription.go +++ b/routes/subscription.go @@ -9,6 +9,7 @@ import ( conf "github.com/muety/wakapi/config" "github.com/muety/wakapi/middlewares" "github.com/muety/wakapi/models" + routeutils "github.com/muety/wakapi/routes/utils" "github.com/muety/wakapi/services" "github.com/stripe/stripe-go/v74" stripePortalSession "github.com/stripe/stripe-go/v74/billingportal/session" @@ -81,7 +82,10 @@ func (h *SubscriptionHandler) RegisterRoutes(router *mux.Router) { subRouterPrivate := subRouterPublic.PathPrefix("").Subrouter() subRouterPrivate.Use( - middlewares.NewAuthenticateMiddleware(h.userSrvc).WithRedirectTarget(defaultErrorRedirectTarget()).Handler, + middlewares.NewAuthenticateMiddleware(h.userSrvc). + WithRedirectTarget(defaultErrorRedirectTarget()). + WithRedirectErrorMessage("unauthorized"). + Handler, ) subRouterPrivate.Path("/checkout").Methods(http.MethodPost).HandlerFunc(h.PostCheckout) subRouterPrivate.Path("/portal").Methods(http.MethodPost).HandlerFunc(h.PostPortal) @@ -94,12 +98,14 @@ func (h *SubscriptionHandler) PostCheckout(w http.ResponseWriter, r *http.Reques user := middlewares.GetPrincipal(r) if user.Email == "" { - http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "missing e-mail address"), http.StatusFound) + routeutils.SetError(r, w, "missing e-mail address") + http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) return } if err := r.ParseForm(); err != nil { - http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "missing form values"), http.StatusFound) + routeutils.SetError(r, w, "missing form values") + http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) return } @@ -125,7 +131,8 @@ func (h *SubscriptionHandler) PostCheckout(w http.ResponseWriter, r *http.Reques session, err := stripeCheckoutSession.New(checkoutParams) if err != nil { conf.Log().Request(r).Error("failed to create stripe checkout session: %v", err) - http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "something went wrong"), http.StatusFound) + routeutils.SetError(r, w, "something went wrong") + http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) return } @@ -139,7 +146,8 @@ func (h *SubscriptionHandler) PostPortal(w http.ResponseWriter, r *http.Request) user := middlewares.GetPrincipal(r) if user.StripeCustomerId == "" { - http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "no subscription found with your e-mail address, please contact us!"), http.StatusFound) + routeutils.SetError(r, w, "no subscription found with your e-mail address, please contact us!") + http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) return } @@ -151,7 +159,8 @@ func (h *SubscriptionHandler) PostPortal(w http.ResponseWriter, r *http.Request) session, err := stripePortalSession.New(portalParams) if err != nil { conf.Log().Request(r).Error("failed to create stripe portal session: %v", err) - http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "something went wrong"), http.StatusFound) + routeutils.SetError(r, w, "something went wrong") + http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) return } @@ -251,7 +260,8 @@ func (h *SubscriptionHandler) PostWebhook(w http.ResponseWriter, r *http.Request } func (h *SubscriptionHandler) GetCheckoutSuccess(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, fmt.Sprintf("%s/settings?success=%s#subscription", h.config.Server.BasePath, "you have successfully subscribed to Wakapi!"), http.StatusFound) + routeutils.SetSuccess(r, w, "you have successfully subscribed to Wakapi!") + http.Redirect(w, r, fmt.Sprintf("%s/settings", h.config.Server.BasePath), http.StatusFound) } func (h *SubscriptionHandler) GetCheckoutCancel(w http.ResponseWriter, r *http.Request) { diff --git a/routes/summary.go b/routes/summary.go index 3b53a18..b033721 100644 --- a/routes/summary.go +++ b/routes/summary.go @@ -29,11 +29,17 @@ func NewSummaryHandler(summaryService services.ISummaryService, userService serv func (h *SummaryHandler) RegisterRoutes(router *mux.Router) { r1 := router.PathPrefix("/summary").Subrouter() - r1.Use(middlewares.NewAuthenticateMiddleware(h.userSrvc).WithRedirectTarget(defaultErrorRedirectTarget()).Handler) + r1.Use(middlewares.NewAuthenticateMiddleware(h.userSrvc). + WithRedirectTarget(defaultErrorRedirectTarget()). + WithRedirectErrorMessage("unauthorized"). + Handler) r1.Methods(http.MethodGet).HandlerFunc(h.GetIndex) r2 := router.PathPrefix("/summary").Subrouter() - r2.Use(middlewares.NewAuthenticateMiddleware(h.userSrvc).WithRedirectTarget(defaultErrorRedirectTarget()).Handler) + r2.Use(middlewares.NewAuthenticateMiddleware(h.userSrvc). + WithRedirectTarget(defaultErrorRedirectTarget()). + WithRedirectErrorMessage("unauthorized"). + Handler) r2.Methods(http.MethodGet).HandlerFunc(h.GetIndex) } @@ -64,14 +70,14 @@ func (h *SummaryHandler) GetIndex(w http.ResponseWriter, r *http.Request) { if err != nil { w.WriteHeader(status) conf.Log().Request(r).Error("failed to load summary - %v", err) - templates[conf.SummaryTemplate].Execute(w, h.buildViewModel(r).WithError(err.Error())) + templates[conf.SummaryTemplate].Execute(w, h.buildViewModel(r, w).WithError(err.Error())) return } user := middlewares.GetPrincipal(r) if user == nil { w.WriteHeader(http.StatusUnauthorized) - templates[conf.SummaryTemplate].Execute(w, h.buildViewModel(r).WithError("unauthorized")) + templates[conf.SummaryTemplate].Execute(w, h.buildViewModel(r, w).WithError("unauthorized")) return } @@ -89,9 +95,6 @@ func (h *SummaryHandler) GetIndex(w http.ResponseWriter, r *http.Request) { templates[conf.SummaryTemplate].Execute(w, vm) } -func (h *SummaryHandler) buildViewModel(r *http.Request) *view.SummaryViewModel { - return &view.SummaryViewModel{ - Success: r.URL.Query().Get("success"), - Error: r.URL.Query().Get("error"), - } +func (h *SummaryHandler) buildViewModel(r *http.Request, w http.ResponseWriter) *view.SummaryViewModel { + return su.WithSessionMessages(&view.SummaryViewModel{}, r, w) } diff --git a/routes/utils/messages.go b/routes/utils/messages.go new file mode 100644 index 0000000..0d14ab3 --- /dev/null +++ b/routes/utils/messages.go @@ -0,0 +1,33 @@ +package utils + +import ( + conf "github.com/muety/wakapi/config" + "github.com/muety/wakapi/models/view" + "net/http" +) + +func SetError(r *http.Request, w http.ResponseWriter, message string) { + setMessage(r, w, message, "error") +} + +func SetSuccess(r *http.Request, w http.ResponseWriter, message string) { + setMessage(r, w, message, "success") +} + +func WithSessionMessages[T view.BasicViewModel](vm T, r *http.Request, w http.ResponseWriter) T { + session, _ := conf.GetSessionStore().Get(r, conf.SessionKeyDefault) + if errors := session.Flashes("error"); len(errors) > 0 { + vm.SetError(errors[0].(string)) + } + if successes := session.Flashes("success"); len(successes) > 0 { + vm.SetSuccess(successes[0].(string)) + } + session.Save(r, w) + return vm +} + +func setMessage(r *http.Request, w http.ResponseWriter, message, key string) { + session, _ := conf.GetSessionStore().Get(r, conf.SessionKeyDefault) + session.AddFlash(message, key) + session.Save(r, w) +} diff --git a/views/alerts.tpl.html b/views/alerts.tpl.html index 497a027..7b85db2 100644 --- a/views/alerts.tpl.html +++ b/views/alerts.tpl.html @@ -1,13 +1,13 @@ {{ if .Error }}
- Error: {{ .Error | capitalize }} + Error: {{ .Messages.Error | capitalize }}
{{ else if .Success }}
- {{ .Success | capitalize }} + {{ .Messages.Success | capitalize }}
{{ end }} \ No newline at end of file