refactor: flash messages framework (resolve #446)

This commit is contained in:
Ferdinand Mütsch 2023-01-02 18:05:28 +01:00
parent a1444bca8c
commit 746608c062
22 changed files with 214 additions and 113 deletions

View File

@ -34,6 +34,8 @@ const (
KeySubscriptionNotificationSent = "sub_reminder"
KeyNewsbox = "newsbox"
SessionKeyDefault = "default"
SimpleDateFormat = "2006-01-02"
SimpleDateTimeFormat = "2006-01-02 15:04:05"
@ -92,6 +94,7 @@ type securityConfig struct {
InsecureCookies bool `yaml:"insecure_cookies" default:"false" env:"WAKAPI_INSECURE_COOKIES"`
CookieMaxAgeSec int `yaml:"cookie_max_age" default:"172800" env:"WAKAPI_COOKIE_MAX_AGE"`
SecureCookie *securecookie.SecureCookie `yaml:"-"`
SessionKey []byte `yaml:"-"`
}
type dbConfig struct {
@ -394,6 +397,7 @@ func Load(version string) *Config {
securecookie.GenerateRandomKey(64),
securecookie.GenerateRandomKey(32),
)
config.Security.SessionKey = securecookie.GenerateRandomKey(32)
if strings.HasSuffix(config.Server.BasePath, "/") {
config.Server.BasePath = config.Server.BasePath[:len(config.Server.BasePath)-1]

14
config/session.go Normal file
View File

@ -0,0 +1,14 @@
package config
import "github.com/gorilla/sessions"
// sessions are only used for displaying flash messages
var sessionStore *sessions.CookieStore
func GetSessionStore() *sessions.CookieStore {
if sessionStore == nil {
sessionStore = sessions.NewCookieStore(Get().Security.SessionKey)
}
return sessionStore
}

1
go.mod
View File

@ -49,6 +49,7 @@ require (
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.13.0 // indirect
github.com/jackc/pgio v1.0.0 // indirect

2
go.sum
View File

@ -75,6 +75,8 @@ github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc=
github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=

View File

@ -22,10 +22,11 @@ var (
)
type AuthenticateMiddleware struct {
config *conf.Config
userSrvc services.IUserService
optionalForPaths []string
redirectTarget string // optional
config *conf.Config
userSrvc services.IUserService
optionalForPaths []string
redirectTarget string // optional
redirectErrorMessage string // optional
}
func NewAuthenticateMiddleware(userService services.IUserService) *AuthenticateMiddleware {
@ -46,6 +47,11 @@ func (m *AuthenticateMiddleware) WithRedirectTarget(path string) *AuthenticateMi
return m
}
func (m *AuthenticateMiddleware) WithRedirectErrorMessage(message string) *AuthenticateMiddleware {
m.redirectErrorMessage = message
return m
}
func (m *AuthenticateMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
m.ServeHTTP(w, r, h.ServeHTTP)
@ -73,6 +79,11 @@ func (m *AuthenticateMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(conf.ErrUnauthorized))
} else {
if m.redirectErrorMessage != "" {
session, _ := conf.GetSessionStore().Get(r, conf.SessionKeyDefault)
session.AddFlash(m.redirectErrorMessage, "error")
session.Save(r, w)
}
http.SetCookie(w, m.config.GetClearCookie(models.AuthCookieKey))
http.Redirect(w, r, m.redirectTarget, http.StatusFound)
}

19
models/view/common.go Normal file
View File

@ -0,0 +1,19 @@
package view
type BasicViewModel interface {
SetError(string)
SetSuccess(string)
}
type Messages struct {
Success string
Error string
}
func (m *Messages) SetError(message string) {
m.Error = message
}
func (m *Messages) SetSuccess(message string) {
m.Success = message
}

View File

@ -6,19 +6,18 @@ type Newsbox struct {
}
type HomeViewModel struct {
Success string
Error string
Messages
TotalHours int
TotalUsers int
Newsbox *Newsbox
}
func (s *HomeViewModel) WithSuccess(m string) *HomeViewModel {
s.Success = m
s.SetSuccess(m)
return s
}
func (s *HomeViewModel) WithError(m string) *HomeViewModel {
s.Error = m
s.SetError(m)
return s
}

View File

@ -1,18 +1,17 @@
package view
type ImprintViewModel struct {
Messages
HtmlText string
Success string
Error string
}
func (s *ImprintViewModel) WithSuccess(m string) *ImprintViewModel {
s.Success = m
s.SetSuccess(m)
return s
}
func (s *ImprintViewModel) WithError(m string) *ImprintViewModel {
s.Error = m
s.SetError(m)
return s
}

View File

@ -8,6 +8,7 @@ import (
)
type LeaderboardViewModel struct {
Messages
User *models.User
By string
Key string
@ -16,17 +17,15 @@ type LeaderboardViewModel struct {
UserLanguages map[string][]string
ApiKey string
PageParams *utils.PageParams
Success string
Error string
}
func (s *LeaderboardViewModel) WithSuccess(m string) *LeaderboardViewModel {
s.Success = m
s.SetSuccess(m)
return s
}
func (s *LeaderboardViewModel) WithError(m string) *LeaderboardViewModel {
s.Error = m
s.SetError(m)
return s
}

View File

@ -1,8 +1,7 @@
package view
type LoginViewModel struct {
Success string
Error string
Messages
TotalUsers int
AllowSignup bool
}
@ -13,11 +12,11 @@ type SetPasswordViewModel struct {
}
func (s *LoginViewModel) WithSuccess(m string) *LoginViewModel {
s.Success = m
s.SetSuccess(m)
return s
}
func (s *LoginViewModel) WithError(m string) *LoginViewModel {
s.Error = m
s.SetError(m)
return s
}

View File

@ -6,6 +6,7 @@ import (
)
type SettingsViewModel struct {
Messages
User *models.User
LanguageMappings []*models.LanguageMapping
Aliases []*SettingsVMCombinedAlias
@ -16,8 +17,6 @@ type SettingsViewModel struct {
UserFirstData time.Time
SupportContact string
ApiKey string
Success string
Error string
}
type SettingsVMCombinedAlias struct {
@ -36,11 +35,11 @@ func (s *SettingsViewModel) SubscriptionsEnabled() bool {
}
func (s *SettingsViewModel) WithSuccess(m string) *SettingsViewModel {
s.Success = m
s.SetSuccess(m)
return s
}
func (s *SettingsViewModel) WithError(m string) *SettingsViewModel {
s.Error = m
s.SetError(m)
return s
}

View File

@ -3,6 +3,7 @@ package view
import "github.com/muety/wakapi/models"
type SummaryViewModel struct {
Messages
*models.Summary
*models.SummaryParams
User *models.User
@ -10,18 +11,16 @@ type SummaryViewModel struct {
EditorColors map[string]string
LanguageColors map[string]string
OSColors map[string]string
Error string
Success string
ApiKey string
RawQuery string
}
func (s *SummaryViewModel) WithSuccess(m string) *SummaryViewModel {
s.Success = m
s.SetSuccess(m)
return s
}
func (s *SummaryViewModel) WithError(m string) *SummaryViewModel {
s.Error = m
s.SetError(m)
return s
}

View File

@ -9,6 +9,7 @@ import (
conf "github.com/muety/wakapi/config"
"github.com/muety/wakapi/models"
"github.com/muety/wakapi/models/view"
routeutils "github.com/muety/wakapi/routes/utils"
"github.com/muety/wakapi/services"
"net/http"
"strconv"
@ -46,10 +47,10 @@ func (h *HomeHandler) GetIndex(w http.ResponseWriter, r *http.Request) {
return
}
templates[conf.IndexTemplate].Execute(w, h.buildViewModel(r))
templates[conf.IndexTemplate].Execute(w, h.buildViewModel(r, w))
}
func (h *HomeHandler) buildViewModel(r *http.Request) *view.HomeViewModel {
func (h *HomeHandler) buildViewModel(r *http.Request, w http.ResponseWriter) *view.HomeViewModel {
var totalHours int
var totalUsers int
var newsbox view.Newsbox
@ -72,11 +73,10 @@ func (h *HomeHandler) buildViewModel(r *http.Request) *view.HomeViewModel {
}
}
return &view.HomeViewModel{
Success: r.URL.Query().Get("success"),
Error: r.URL.Query().Get("error"),
vm := &view.HomeViewModel{
TotalHours: totalHours,
TotalUsers: totalUsers,
Newsbox: &newsbox,
}
return routeutils.WithSessionMessages(vm, r, w)
}

View File

@ -39,8 +39,5 @@ func (h *ImprintHandler) GetImprint(w http.ResponseWriter, r *http.Request) {
}
func (h *ImprintHandler) buildViewModel(r *http.Request) *view.ImprintViewModel {
return &view.ImprintViewModel{
Success: r.URL.Query().Get("success"),
Error: r.URL.Query().Get("error"),
}
return &view.ImprintViewModel{}
}

View File

@ -9,6 +9,7 @@ import (
"github.com/muety/wakapi/middlewares"
"github.com/muety/wakapi/models"
"github.com/muety/wakapi/models/view"
routeutils "github.com/muety/wakapi/routes/utils"
"github.com/muety/wakapi/services"
"github.com/muety/wakapi/utils"
"net/http"
@ -38,6 +39,7 @@ func (h *LeaderboardHandler) RegisterRoutes(router *mux.Router) {
r.Use(
middlewares.NewAuthenticateMiddleware(h.userService).
WithRedirectTarget(defaultErrorRedirectTarget()).
WithRedirectErrorMessage("unauthorized").
WithOptionalFor([]string{"/"}).
Handler,
)
@ -48,12 +50,12 @@ func (h *LeaderboardHandler) GetIndex(w http.ResponseWriter, r *http.Request) {
if h.config.IsDev() {
loadTemplates()
}
if err := templates[conf.LeaderboardTemplate].Execute(w, h.buildViewModel(r)); err != nil {
if err := templates[conf.LeaderboardTemplate].Execute(w, h.buildViewModel(r, w)); err != nil {
logbuch.Error(err.Error())
}
}
func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardViewModel {
func (h *LeaderboardHandler) buildViewModel(r *http.Request, w http.ResponseWriter) *view.LeaderboardViewModel {
user := middlewares.GetPrincipal(r)
byParam := strings.ToLower(r.URL.Query().Get("by"))
keyParam := strings.ToLower(r.URL.Query().Get("key"))
@ -71,7 +73,9 @@ func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardVi
leaderboard, err = h.leaderboardService.GetByInterval(models.IntervalPast7Days, pageParams, true)
if err != nil {
conf.Log().Request(r).Error("error while fetching general leaderboard items - %v", err)
return &view.LeaderboardViewModel{Error: criticalError}
return &view.LeaderboardViewModel{
Messages: view.Messages{Error: criticalError},
}
}
// regardless of page, always show own rank
@ -88,7 +92,9 @@ func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardVi
leaderboard, err = h.leaderboardService.GetAggregatedByInterval(models.IntervalPast7Days, &by, pageParams, true)
if err != nil {
conf.Log().Request(r).Error("error while fetching general leaderboard items - %v", err)
return &view.LeaderboardViewModel{Error: criticalError}
return &view.LeaderboardViewModel{
Messages: view.Messages{Error: criticalError},
}
}
// regardless of page, always show own rank
@ -120,7 +126,9 @@ func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardVi
leaderboard = leaderboard.TopByKey(by, keyParam)
}
} else {
return &view.LeaderboardViewModel{Error: fmt.Sprintf("unsupported aggregation '%s'", byParam)}
return &view.LeaderboardViewModel{
Messages: view.Messages{Error: fmt.Sprintf("unsupported aggregation '%s'", byParam)},
}
}
}
@ -129,7 +137,7 @@ func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardVi
apiKey = user.ApiKey
}
return &view.LeaderboardViewModel{
vm := &view.LeaderboardViewModel{
User: user,
By: byParam,
Key: keyParam,
@ -138,7 +146,6 @@ func (h *LeaderboardHandler) buildViewModel(r *http.Request) *view.LeaderboardVi
TopKeys: topKeys,
ApiKey: apiKey,
PageParams: pageParams,
Success: r.URL.Query().Get("success"),
Error: r.URL.Query().Get("error"),
}
return routeutils.WithSessionMessages(vm, r, w)
}

View File

@ -8,6 +8,7 @@ import (
"github.com/muety/wakapi/middlewares"
"github.com/muety/wakapi/models"
"github.com/muety/wakapi/models/view"
routeutils "github.com/muety/wakapi/routes/utils"
"github.com/muety/wakapi/services"
"github.com/muety/wakapi/utils"
"net/http"
@ -41,6 +42,7 @@ func (h *LoginHandler) RegisterRoutes(router *mux.Router) {
authMiddleware := middlewares.NewAuthenticateMiddleware(h.userSrvc).
WithRedirectTarget(defaultErrorRedirectTarget()).
WithRedirectErrorMessage("unauthorized").
WithOptionalFor([]string{"/logout"})
logoutRouter := router.PathPrefix("/logout").Subrouter()
@ -58,7 +60,7 @@ func (h *LoginHandler) GetIndex(w http.ResponseWriter, r *http.Request) {
return
}
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r))
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w))
}
func (h *LoginHandler) PostLogin(w http.ResponseWriter, r *http.Request) {
@ -74,25 +76,25 @@ func (h *LoginHandler) PostLogin(w http.ResponseWriter, r *http.Request) {
var login models.Login
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters"))
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters"))
return
}
if err := loginDecoder.Decode(&login, r.PostForm); err != nil {
w.WriteHeader(http.StatusBadRequest)
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters"))
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters"))
return
}
user, err := h.userSrvc.GetUserById(login.Username)
if err != nil {
w.WriteHeader(http.StatusNotFound)
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r).WithError("resource not found"))
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w).WithError("resource not found"))
return
}
if !utils.CompareBcrypt(user.Password, login.Password, h.config.Security.PasswordSalt) {
w.WriteHeader(http.StatusUnauthorized)
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r).WithError("invalid credentials"))
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w).WithError("invalid credentials"))
return
}
@ -100,7 +102,7 @@ func (h *LoginHandler) PostLogin(w http.ResponseWriter, r *http.Request) {
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
conf.Log().Request(r).Error("failed to encode secure cookie - %v", err)
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r).WithError("internal server error"))
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w).WithError("internal server error"))
return
}
@ -133,7 +135,7 @@ func (h *LoginHandler) GetSignup(w http.ResponseWriter, r *http.Request) {
return
}
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r))
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w))
}
func (h *LoginHandler) PostSignup(w http.ResponseWriter, r *http.Request) {
@ -143,7 +145,7 @@ func (h *LoginHandler) PostSignup(w http.ResponseWriter, r *http.Request) {
if !h.config.IsDev() && !h.config.Security.AllowSignup {
w.WriteHeader(http.StatusForbidden)
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("registration is disabled on this server"))
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("registration is disabled on this server"))
return
}
@ -155,18 +157,18 @@ func (h *LoginHandler) PostSignup(w http.ResponseWriter, r *http.Request) {
var signup models.Signup
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters"))
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters"))
return
}
if err := signupDecoder.Decode(&signup, r.PostForm); err != nil {
w.WriteHeader(http.StatusBadRequest)
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters"))
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters"))
return
}
if !signup.IsValid() {
w.WriteHeader(http.StatusBadRequest)
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("invalid parameters"))
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("invalid parameters"))
return
}
@ -176,23 +178,24 @@ func (h *LoginHandler) PostSignup(w http.ResponseWriter, r *http.Request) {
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
conf.Log().Request(r).Error("failed to create new user - %v", err)
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("failed to create new user"))
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("failed to create new user"))
return
}
if !created {
w.WriteHeader(http.StatusConflict)
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r).WithError("user already existing"))
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w).WithError("user already existing"))
return
}
http.Redirect(w, r, fmt.Sprintf("%s/?success=%s", h.config.Server.BasePath, "account created successfully"), http.StatusFound)
routeutils.SetSuccess(r, w, "account created successfully")
http.Redirect(w, r, h.config.Server.BasePath, http.StatusFound)
}
func (h *LoginHandler) GetResetPassword(w http.ResponseWriter, r *http.Request) {
if h.config.IsDev() {
loadTemplates()
}
templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r))
templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r, w))
}
func (h *LoginHandler) GetSetPassword(w http.ResponseWriter, r *http.Request) {
@ -204,12 +207,12 @@ func (h *LoginHandler) GetSetPassword(w http.ResponseWriter, r *http.Request) {
token := values.Get("token")
if token == "" {
w.WriteHeader(http.StatusUnauthorized)
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("invalid or missing token"))
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("invalid or missing token"))
return
}
vm := &view.SetPasswordViewModel{
LoginViewModel: *h.buildViewModel(r),
LoginViewModel: *h.buildViewModel(r, w),
Token: token,
}
@ -224,25 +227,25 @@ func (h *LoginHandler) PostSetPassword(w http.ResponseWriter, r *http.Request) {
var setRequest models.SetPasswordRequest
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters"))
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters"))
return
}
if err := signupDecoder.Decode(&setRequest, r.PostForm); err != nil {
w.WriteHeader(http.StatusBadRequest)
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters"))
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters"))
return
}
user, err := h.userSrvc.GetUserByResetToken(setRequest.Token)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("invalid token"))
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("invalid token"))
return
}
if !setRequest.IsValid() {
w.WriteHeader(http.StatusBadRequest)
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("invalid parameters"))
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("invalid parameters"))
return
}
@ -251,7 +254,7 @@ func (h *LoginHandler) PostSetPassword(w http.ResponseWriter, r *http.Request) {
if hash, err := utils.HashBcrypt(user.Password, h.config.Security.PasswordSalt); err != nil {
w.WriteHeader(http.StatusInternalServerError)
conf.Log().Request(r).Error("failed to set new password - %v", err)
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("failed to set new password"))
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("failed to set new password"))
return
} else {
user.Password = hash
@ -260,11 +263,12 @@ func (h *LoginHandler) PostSetPassword(w http.ResponseWriter, r *http.Request) {
if _, err := h.userSrvc.Update(user); err != nil {
w.WriteHeader(http.StatusInternalServerError)
conf.Log().Request(r).Error("failed to save new password - %v", err)
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("failed to save new password"))
templates[conf.SetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("failed to save new password"))
return
}
http.Redirect(w, r, fmt.Sprintf("%s/login?success=%s", h.config.Server.BasePath, "password updated successfully"), http.StatusFound)
routeutils.SetSuccess(r, w, "password updated successfully")
http.Redirect(w, r, fmt.Sprintf("%s/login", h.config.Server.BasePath), http.StatusFound)
}
func (h *LoginHandler) PostResetPassword(w http.ResponseWriter, r *http.Request) {
@ -274,19 +278,19 @@ func (h *LoginHandler) PostResetPassword(w http.ResponseWriter, r *http.Request)
if !h.config.Mail.Enabled {
w.WriteHeader(http.StatusNotImplemented)
templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("mailing is disabled on this server"))
templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("mailing is disabled on this server"))
return
}
var resetRequest models.ResetPasswordRequest
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters"))
templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters"))
return
}
if err := resetPasswordDecoder.Decode(&resetRequest, r.PostForm); err != nil {
w.WriteHeader(http.StatusBadRequest)
templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("missing parameters"))
templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing parameters"))
return
}
@ -294,7 +298,7 @@ func (h *LoginHandler) PostResetPassword(w http.ResponseWriter, r *http.Request)
if u, err := h.userSrvc.GenerateResetToken(user); err != nil {
w.WriteHeader(http.StatusInternalServerError)
conf.Log().Request(r).Error("failed to generate password reset token - %v", err)
templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r).WithError("failed to generate password reset token"))
templates[conf.ResetPasswordTemplate].Execute(w, h.buildViewModel(r, w).WithError("failed to generate password reset token"))
return
} else {
go func(user *models.User) {
@ -310,16 +314,16 @@ func (h *LoginHandler) PostResetPassword(w http.ResponseWriter, r *http.Request)
conf.Log().Request(r).Warn("password reset requested for unregistered address '%s'", resetRequest.Email)
}
http.Redirect(w, r, fmt.Sprintf("%s/?success=%s", h.config.Server.BasePath, "an e-mail was sent to you in case your e-mail address was registered"), http.StatusFound)
routeutils.SetSuccess(r, w, "an e-mail was sent to you in case your e-mail address was registered")
http.Redirect(w, r, h.config.Server.BasePath, http.StatusFound)
}
func (h *LoginHandler) buildViewModel(r *http.Request) *view.LoginViewModel {
func (h *LoginHandler) buildViewModel(r *http.Request, w http.ResponseWriter) *view.LoginViewModel {
numUsers, _ := h.userSrvc.Count()
return &view.LoginViewModel{
Success: r.URL.Query().Get("success"),
Error: r.URL.Query().Get("error"),
vm := &view.LoginViewModel{
TotalUsers: int(numUsers),
AllowSignup: h.config.IsDev() || h.config.Security.AllowSignup,
}
return routeutils.WithSessionMessages(vm, r, w)
}

View File

@ -1,7 +1,6 @@
package routes
import (
"fmt"
"github.com/muety/wakapi/helpers"
"html/template"
"net/http"
@ -105,7 +104,7 @@ func loadTemplates() {
}
func defaultErrorRedirectTarget() string {
return fmt.Sprintf("%s/?error=unauthorized", config.Get().Server.BasePath)
return config.Get().Server.BasePath + "/"
}
func add(i, j int) int {

View File

@ -69,7 +69,10 @@ func NewSettingsHandler(
func (h *SettingsHandler) RegisterRoutes(router *mux.Router) {
r := router.PathPrefix("/settings").Subrouter()
r.Use(
middlewares.NewAuthenticateMiddleware(h.userSrvc).WithRedirectTarget(defaultErrorRedirectTarget()).Handler,
middlewares.NewAuthenticateMiddleware(h.userSrvc).
WithRedirectTarget(defaultErrorRedirectTarget()).
WithRedirectErrorMessage("unauthorized").
Handler,
)
r.Methods(http.MethodGet).HandlerFunc(h.GetIndex)
r.Methods(http.MethodPost).HandlerFunc(h.PostIndex)
@ -79,7 +82,7 @@ func (h *SettingsHandler) GetIndex(w http.ResponseWriter, r *http.Request) {
if h.config.IsDev() {
loadTemplates()
}
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r))
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w))
}
func (h *SettingsHandler) PostIndex(w http.ResponseWriter, r *http.Request) {
@ -89,7 +92,7 @@ func (h *SettingsHandler) PostIndex(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r).WithError("missing form values"))
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w).WithError("missing form values"))
return
}
@ -100,7 +103,7 @@ func (h *SettingsHandler) PostIndex(w http.ResponseWriter, r *http.Request) {
if actionFunc == nil {
logbuch.Warn("failed to dispatch action '%s'", action)
w.WriteHeader(http.StatusBadRequest)
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r).WithError("unknown action requests"))
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w).WithError("unknown action requests"))
return
}
@ -113,15 +116,15 @@ func (h *SettingsHandler) PostIndex(w http.ResponseWriter, r *http.Request) {
if errorMsg != "" {
w.WriteHeader(status)
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r).WithError(errorMsg))
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w).WithError(errorMsg))
return
}
if successMsg != "" {
w.WriteHeader(status)
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r).WithSuccess(successMsg))
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w).WithSuccess(successMsg))
return
}
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r))
templates[conf.SettingsTemplate].Execute(w, h.buildViewModel(r, w))
}
func (h *SettingsHandler) dispatchAction(action string) action {
@ -622,8 +625,9 @@ func (h *SettingsHandler) actionDeleteUser(w http.ResponseWriter, r *http.Reques
}
}(user)
routeutils.SetSuccess(r, w, "Your account will be deleted in a few minutes. Sorry to you go.")
http.SetCookie(w, h.config.GetClearCookie(models.AuthCookieKey))
http.Redirect(w, r, fmt.Sprintf("%s/?success=%s", h.config.Server.BasePath, "Your account will be deleted in a few minutes. Sorry to you go."), http.StatusFound)
http.Redirect(w, r, h.config.Server.BasePath, http.StatusFound)
return -1, "", ""
}
@ -673,7 +677,7 @@ func (h *SettingsHandler) regenerateSummaries(user *models.User) error {
return nil
}
func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewModel {
func (h *SettingsHandler) buildViewModel(r *http.Request, w http.ResponseWriter) *view.SettingsViewModel {
user := middlewares.GetPrincipal(r)
// mappings
@ -683,7 +687,7 @@ func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewMode
aliases, err := h.aliasSrvc.GetByUser(user.ID)
if err != nil {
conf.Log().Request(r).Error("error while building alias map - %v", err)
return &view.SettingsViewModel{Error: criticalError}
return &view.SettingsViewModel{Messages: view.Messages{Error: criticalError}}
}
aliasMap := make(map[string][]*models.Alias)
for _, a := range aliases {
@ -712,7 +716,7 @@ func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewMode
labelMap, err := h.projectLabelSrvc.GetByUserGroupedInverted(user.ID)
if err != nil {
conf.Log().Request(r).Error("error while building settings project label map - %v", err)
return &view.SettingsViewModel{Error: criticalError}
return &view.SettingsViewModel{Messages: view.Messages{Error: criticalError}}
}
combinedLabels := make([]*view.SettingsVMCombinedLabel, 0)
@ -734,7 +738,7 @@ func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewMode
projects, err := routeutils.GetEffectiveProjectsList(user, h.heartbeatSrvc, h.aliasSrvc)
if err != nil {
conf.Log().Request(r).Error("error while fetching projects - %v", err)
return &view.SettingsViewModel{Error: criticalError}
return &view.SettingsViewModel{Messages: view.Messages{Error: criticalError}}
}
// subscriptions
@ -750,7 +754,7 @@ func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewMode
firstData, _ = time.Parse(time.RFC822Z, firstDataKv.Value)
}
return &view.SettingsViewModel{
vm := &view.SettingsViewModel{
User: user,
LanguageMappings: mappings,
Aliases: combinedAliases,
@ -761,7 +765,6 @@ func (h *SettingsHandler) buildViewModel(r *http.Request) *view.SettingsViewMode
SubscriptionPrice: subscriptionPrice,
SupportContact: h.config.App.SupportContact,
DataRetentionMonths: h.config.App.DataRetentionMonths,
Success: r.URL.Query().Get("success"),
Error: r.URL.Query().Get("error"),
}
return routeutils.WithSessionMessages(vm, r, w)
}

View File

@ -9,6 +9,7 @@ import (
conf "github.com/muety/wakapi/config"
"github.com/muety/wakapi/middlewares"
"github.com/muety/wakapi/models"
routeutils "github.com/muety/wakapi/routes/utils"
"github.com/muety/wakapi/services"
"github.com/stripe/stripe-go/v74"
stripePortalSession "github.com/stripe/stripe-go/v74/billingportal/session"
@ -81,7 +82,10 @@ func (h *SubscriptionHandler) RegisterRoutes(router *mux.Router) {
subRouterPrivate := subRouterPublic.PathPrefix("").Subrouter()
subRouterPrivate.Use(
middlewares.NewAuthenticateMiddleware(h.userSrvc).WithRedirectTarget(defaultErrorRedirectTarget()).Handler,
middlewares.NewAuthenticateMiddleware(h.userSrvc).
WithRedirectTarget(defaultErrorRedirectTarget()).
WithRedirectErrorMessage("unauthorized").
Handler,
)
subRouterPrivate.Path("/checkout").Methods(http.MethodPost).HandlerFunc(h.PostCheckout)
subRouterPrivate.Path("/portal").Methods(http.MethodPost).HandlerFunc(h.PostPortal)
@ -94,12 +98,14 @@ func (h *SubscriptionHandler) PostCheckout(w http.ResponseWriter, r *http.Reques
user := middlewares.GetPrincipal(r)
if user.Email == "" {
http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "missing e-mail address"), http.StatusFound)
routeutils.SetError(r, w, "missing e-mail address")
http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound)
return
}
if err := r.ParseForm(); err != nil {
http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "missing form values"), http.StatusFound)
routeutils.SetError(r, w, "missing form values")
http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound)
return
}
@ -125,7 +131,8 @@ func (h *SubscriptionHandler) PostCheckout(w http.ResponseWriter, r *http.Reques
session, err := stripeCheckoutSession.New(checkoutParams)
if err != nil {
conf.Log().Request(r).Error("failed to create stripe checkout session: %v", err)
http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "something went wrong"), http.StatusFound)
routeutils.SetError(r, w, "something went wrong")
http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound)
return
}
@ -139,7 +146,8 @@ func (h *SubscriptionHandler) PostPortal(w http.ResponseWriter, r *http.Request)
user := middlewares.GetPrincipal(r)
if user.StripeCustomerId == "" {
http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "no subscription found with your e-mail address, please contact us!"), http.StatusFound)
routeutils.SetError(r, w, "no subscription found with your e-mail address, please contact us!")
http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound)
return
}
@ -151,7 +159,8 @@ func (h *SubscriptionHandler) PostPortal(w http.ResponseWriter, r *http.Request)
session, err := stripePortalSession.New(portalParams)
if err != nil {
conf.Log().Request(r).Error("failed to create stripe portal session: %v", err)
http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "something went wrong"), http.StatusFound)
routeutils.SetError(r, w, "something went wrong")
http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound)
return
}
@ -251,7 +260,8 @@ func (h *SubscriptionHandler) PostWebhook(w http.ResponseWriter, r *http.Request
}
func (h *SubscriptionHandler) GetCheckoutSuccess(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, fmt.Sprintf("%s/settings?success=%s#subscription", h.config.Server.BasePath, "you have successfully subscribed to Wakapi!"), http.StatusFound)
routeutils.SetSuccess(r, w, "you have successfully subscribed to Wakapi!")
http.Redirect(w, r, fmt.Sprintf("%s/settings", h.config.Server.BasePath), http.StatusFound)
}
func (h *SubscriptionHandler) GetCheckoutCancel(w http.ResponseWriter, r *http.Request) {

View File

@ -29,11 +29,17 @@ func NewSummaryHandler(summaryService services.ISummaryService, userService serv
func (h *SummaryHandler) RegisterRoutes(router *mux.Router) {
r1 := router.PathPrefix("/summary").Subrouter()
r1.Use(middlewares.NewAuthenticateMiddleware(h.userSrvc).WithRedirectTarget(defaultErrorRedirectTarget()).Handler)
r1.Use(middlewares.NewAuthenticateMiddleware(h.userSrvc).
WithRedirectTarget(defaultErrorRedirectTarget()).
WithRedirectErrorMessage("unauthorized").
Handler)
r1.Methods(http.MethodGet).HandlerFunc(h.GetIndex)
r2 := router.PathPrefix("/summary").Subrouter()
r2.Use(middlewares.NewAuthenticateMiddleware(h.userSrvc).WithRedirectTarget(defaultErrorRedirectTarget()).Handler)
r2.Use(middlewares.NewAuthenticateMiddleware(h.userSrvc).
WithRedirectTarget(defaultErrorRedirectTarget()).
WithRedirectErrorMessage("unauthorized").
Handler)
r2.Methods(http.MethodGet).HandlerFunc(h.GetIndex)
}
@ -64,14 +70,14 @@ func (h *SummaryHandler) GetIndex(w http.ResponseWriter, r *http.Request) {
if err != nil {
w.WriteHeader(status)
conf.Log().Request(r).Error("failed to load summary - %v", err)
templates[conf.SummaryTemplate].Execute(w, h.buildViewModel(r).WithError(err.Error()))
templates[conf.SummaryTemplate].Execute(w, h.buildViewModel(r, w).WithError(err.Error()))
return
}
user := middlewares.GetPrincipal(r)
if user == nil {
w.WriteHeader(http.StatusUnauthorized)
templates[conf.SummaryTemplate].Execute(w, h.buildViewModel(r).WithError("unauthorized"))
templates[conf.SummaryTemplate].Execute(w, h.buildViewModel(r, w).WithError("unauthorized"))
return
}
@ -89,9 +95,6 @@ func (h *SummaryHandler) GetIndex(w http.ResponseWriter, r *http.Request) {
templates[conf.SummaryTemplate].Execute(w, vm)
}
func (h *SummaryHandler) buildViewModel(r *http.Request) *view.SummaryViewModel {
return &view.SummaryViewModel{
Success: r.URL.Query().Get("success"),
Error: r.URL.Query().Get("error"),
}
func (h *SummaryHandler) buildViewModel(r *http.Request, w http.ResponseWriter) *view.SummaryViewModel {
return su.WithSessionMessages(&view.SummaryViewModel{}, r, w)
}

33
routes/utils/messages.go Normal file
View File

@ -0,0 +1,33 @@
package utils
import (
conf "github.com/muety/wakapi/config"
"github.com/muety/wakapi/models/view"
"net/http"
)
func SetError(r *http.Request, w http.ResponseWriter, message string) {
setMessage(r, w, message, "error")
}
func SetSuccess(r *http.Request, w http.ResponseWriter, message string) {
setMessage(r, w, message, "success")
}
func WithSessionMessages[T view.BasicViewModel](vm T, r *http.Request, w http.ResponseWriter) T {
session, _ := conf.GetSessionStore().Get(r, conf.SessionKeyDefault)
if errors := session.Flashes("error"); len(errors) > 0 {
vm.SetError(errors[0].(string))
}
if successes := session.Flashes("success"); len(successes) > 0 {
vm.SetSuccess(successes[0].(string))
}
session.Save(r, w)
return vm
}
func setMessage(r *http.Request, w http.ResponseWriter, message, key string) {
session, _ := conf.GetSessionStore().Get(r, conf.SessionKeyDefault)
session.AddFlash(message, key)
session.Save(r, w)
}

View File

@ -1,13 +1,13 @@
{{ if .Error }}
<div class="flex justify-center w-full">
<div class="p-4 font-semibold text-white text-sm bg-red-500 rounded mt-16 shadow grow max-w-lg">
Error: {{ .Error | capitalize }}
Error: {{ .Messages.Error | capitalize }}
</div>
</div>
{{ else if .Success }}
<div class="flex justify-center w-full">
<div class="p-4 font-semibold text-white text-sm bg-green-500 rounded mt-16 shadow grow max-w-lg">
{{ .Success | capitalize }}
{{ .Messages.Success | capitalize }}
</div>
</div>
{{ end }}