diff --git a/middlewares/authenticate.go b/middlewares/authenticate.go index 86f215a..7af14fa 100644 --- a/middlewares/authenticate.go +++ b/middlewares/authenticate.go @@ -5,21 +5,16 @@ import ( "errors" "fmt" conf "github.com/muety/wakapi/config" + "github.com/muety/wakapi/models" + "github.com/muety/wakapi/services" "github.com/muety/wakapi/utils" "log" "net/http" "strings" - "time" - - "github.com/patrickmn/go-cache" - - "github.com/muety/wakapi/models" - "github.com/muety/wakapi/services" ) type AuthenticateMiddleware struct { config *conf.Config - cache *cache.Cache userSrvc services.IUserService whitelistPaths []string } @@ -28,7 +23,6 @@ func NewAuthenticateMiddleware(userService services.IUserService, whitelistPaths return &AuthenticateMiddleware{ config: conf.Get(), userSrvc: userService, - cache: cache.New(1*time.Hour, 2*time.Hour), whitelistPaths: whitelistPaths, } } @@ -64,8 +58,6 @@ func (m *AuthenticateMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques return } - m.cache.Set(user.ID, user, cache.DefaultExpiration) - ctx := context.WithValue(r.Context(), models.UserKey, user) next(w, r.WithContext(ctx)) } @@ -78,14 +70,9 @@ func (m *AuthenticateMiddleware) tryGetUserByApiKey(r *http.Request) (*models.Us var user *models.User userKey := strings.TrimSpace(key) - cachedUser, ok := m.cache.Get(userKey) - if !ok { - user, err = m.userSrvc.GetUserByKey(userKey) - if err != nil { - return nil, err - } - } else { - user = cachedUser.(*models.User) + user, err = m.userSrvc.GetUserByKey(userKey) + if err != nil { + return nil, err } return user, nil } @@ -96,12 +83,6 @@ func (m *AuthenticateMiddleware) tryGetUserByCookie(r *http.Request) (*models.Us return nil, err } - cachedUser, ok := m.cache.Get(login.Username) - - if ok { - return cachedUser.(*models.User), nil - } - user, err := m.userSrvc.GetUserById(login.Username) if err != nil { return nil, err diff --git a/services/user.go b/services/user.go index bdb0ee5..27d298c 100644 --- a/services/user.go +++ b/services/user.go @@ -5,11 +5,14 @@ import ( "github.com/muety/wakapi/models" "github.com/muety/wakapi/repositories" "github.com/muety/wakapi/utils" + "github.com/patrickmn/go-cache" uuid "github.com/satori/go.uuid" + "time" ) type UserService struct { Config *config.Config + cache *cache.Cache repository repositories.IUserRepository } @@ -17,15 +20,36 @@ func NewUserService(userRepo repositories.IUserRepository) *UserService { return &UserService{ Config: config.Get(), repository: userRepo, + cache: cache.New(1*time.Hour, 2*time.Hour), } } func (srv *UserService) GetUserById(userId string) (*models.User, error) { - return srv.repository.GetById(userId) + if u, ok := srv.cache.Get(userId); ok { + return u.(*models.User), nil + } + + u, err := srv.repository.GetById(userId) + if err != nil { + return nil, err + } + + srv.cache.Set(u.ID, u, cache.DefaultExpiration) + return u, nil } func (srv *UserService) GetUserByKey(key string) (*models.User, error) { - return srv.repository.GetByApiKey(key) + if u, ok := srv.cache.Get(key); ok { + return u.(*models.User), nil + } + + u, err := srv.repository.GetByApiKey(key) + if err != nil { + return nil, err + } + + srv.cache.Set(u.ID, u, cache.DefaultExpiration) + return u, nil } func (srv *UserService) GetAll() ([]*models.User, error) { @@ -49,19 +73,23 @@ func (srv *UserService) CreateOrGet(signup *models.Signup) (*models.User, bool, } func (srv *UserService) Update(user *models.User) (*models.User, error) { + srv.cache.Flush() return srv.repository.Update(user) } func (srv *UserService) ResetApiKey(user *models.User) (*models.User, error) { + srv.cache.Flush() user.ApiKey = uuid.NewV4().String() return srv.Update(user) } func (srv *UserService) ToggleBadges(user *models.User) (*models.User, error) { + srv.cache.Flush() return srv.repository.UpdateField(user, "badges_enabled", !user.BadgesEnabled) } func (srv *UserService) MigrateMd5Password(user *models.User, login *models.Login) (*models.User, error) { + srv.cache.Flush() user.Password = login.Password if hash, err := utils.HashBcrypt(user.Password, srv.Config.Security.PasswordSalt); err != nil { return nil, err