diff --git a/config.default.yml b/config.default.yml index a6f159c..e2ba27b 100644 --- a/config.default.yml +++ b/config.default.yml @@ -65,6 +65,7 @@ sentry: # only relevant for running wakapi as a hosted service with paid subscriptions and stripe payments subscriptions: enabled: false + expiry_notifications: true stripe_api_key: stripe_secret_key: stripe_endpoint_secret: diff --git a/models/user.go b/models/user.go index 8f0d8e3..701bc1f 100644 --- a/models/user.go +++ b/models/user.go @@ -36,6 +36,7 @@ type User struct { ReportsWeekly bool `json:"-" gorm:"default:false; type:bool"` PublicLeaderboard bool `json:"-" gorm:"default:false; type:bool"` SubscribedUntil *CustomTime `json:"-" gorm:"type:timestamp" swaggertype:"string" format:"date" example:"2006-01-02 15:04:05.000"` + StripeCustomerId string `json:"-"` } type Login struct { diff --git a/repositories/repositories.go b/repositories/repositories.go index 8b36eb7..5ea3ed5 100644 --- a/repositories/repositories.go +++ b/repositories/repositories.go @@ -72,11 +72,8 @@ type ISummaryRepository interface { } type IUserRepository interface { - GetById(string) (*models.User, error) + FindOne(user models.User) (*models.User, error) GetByIds([]string) ([]*models.User, error) - GetByApiKey(string) (*models.User, error) - GetByEmail(string) (*models.User, error) - GetByResetToken(string) (*models.User, error) GetAll() ([]*models.User, error) GetMany([]string) ([]*models.User, error) GetAllByReports(bool) ([]*models.User, error) diff --git a/repositories/user.go b/repositories/user.go index 7d6757f..e5ebd09 100644 --- a/repositories/user.go +++ b/repositories/user.go @@ -15,9 +15,9 @@ func NewUserRepository(db *gorm.DB) *UserRepository { return &UserRepository{db: db} } -func (r *UserRepository) GetById(userId string) (*models.User, error) { +func (r *UserRepository) FindOne(attributes models.User) (*models.User, error) { u := &models.User{} - if err := r.db.Where(&models.User{ID: userId}).First(u).Error; err != nil { + if err := r.db.Where(&attributes).First(u).Error; err != nil { return u, err } return u, nil @@ -34,39 +34,6 @@ func (r *UserRepository) GetByIds(userIds []string) ([]*models.User, error) { return users, nil } -func (r *UserRepository) GetByApiKey(key string) (*models.User, error) { - if key == "" { - return nil, errors.New("invalid input") - } - u := &models.User{} - if err := r.db.Where(&models.User{ApiKey: key}).First(u).Error; err != nil { - return u, err - } - return u, nil -} - -func (r *UserRepository) GetByResetToken(resetToken string) (*models.User, error) { - if resetToken == "" { - return nil, errors.New("invalid input") - } - u := &models.User{} - if err := r.db.Where(&models.User{ResetToken: resetToken}).First(u).Error; err != nil { - return u, err - } - return u, nil -} - -func (r *UserRepository) GetByEmail(email string) (*models.User, error) { - if email == "" { - return nil, errors.New("invalid input") - } - u := &models.User{} - if err := r.db.Where(&models.User{Email: email}).First(u).Error; err != nil { - return u, err - } - return u, nil -} - func (r *UserRepository) GetAll() ([]*models.User, error) { var users []*models.User if err := r.db. @@ -144,7 +111,7 @@ func (r *UserRepository) Count() (int64, error) { } func (r *UserRepository) InsertOrGet(user *models.User) (*models.User, bool, error) { - if u, err := r.GetById(user.ID); err == nil && u != nil && u.ID != "" { + if u, err := r.FindOne(models.User{ID: user.ID}); err == nil && u != nil && u.ID != "" { return u, false, nil } @@ -177,6 +144,7 @@ func (r *UserRepository) Update(user *models.User) (*models.User, error) { "reports_weekly": user.ReportsWeekly, "public_leaderboard": user.PublicLeaderboard, "subscribed_until": user.SubscribedUntil, + "stripe_customer_id": user.StripeCustomerId, } result := r.db.Model(user).Updates(updateMap) diff --git a/routes/settings.go b/routes/settings.go index b99ae2d..46743c6 100644 --- a/routes/settings.go +++ b/routes/settings.go @@ -181,7 +181,7 @@ func (h *SettingsHandler) actionUpdateUser(w http.ResponseWriter, r *http.Reques return http.StatusBadRequest, "", "invalid parameters" } - if user.Email == "" && user.HasActiveSubscription() { + if payload.Email == "" && user.HasActiveSubscription() { return http.StatusBadRequest, "", "cannot unset email while subscription is active" } diff --git a/routes/subscription.go b/routes/subscription.go index 8c29944..574598d 100644 --- a/routes/subscription.go +++ b/routes/subscription.go @@ -22,6 +22,15 @@ import ( "time" ) +/* + How to integrate with Stripe? + --- + 1. Create a plan with recurring payment (https://dashboard.stripe.com/test/products?active=true), copy its ID and save it as 'standard_price_id' + 2. Create a webhook (https://dashboard.stripe.com/test/webhooks), with target URL '/subscription/webhook' and events ['customer.subscription.created', 'customer.subscription.updated', 'customer.subscription.deleted', 'checkout.session.completed'], copy the endpoint secret and save it to 'stripe_endpoint_secret' + 3. Create a secret API key (https://dashboard.stripe.com/test/apikeys), copy it and save it to 'stripe_secret_key' + 4. Copy the publishable API key (https://dashboard.stripe.com/test/apikeys) and save it to 'stripe_api_key' +*/ + type SubscriptionHandler struct { config *conf.Config userSrvc services.IUserService @@ -102,12 +111,17 @@ func (h *SubscriptionHandler) PostCheckout(w http.ResponseWriter, r *http.Reques Quantity: stripe.Int64(1), }, }, - CustomerEmail: &user.Email, - ClientReferenceID: &user.Email, + ClientReferenceID: &user.ID, SuccessURL: stripe.String(fmt.Sprintf("%s%s/subscription/success", h.config.Server.PublicUrl, h.config.Server.BasePath)), CancelURL: stripe.String(fmt.Sprintf("%s%s/subscription/cancel", h.config.Server.PublicUrl, h.config.Server.BasePath)), } + if user.StripeCustomerId != "" { + checkoutParams.Customer = &user.StripeCustomerId + } else { + checkoutParams.CustomerEmail = &user.Email + } + session, err := stripeCheckoutSession.New(checkoutParams) if err != nil { conf.Log().Request(r).Error("failed to create stripe checkout session: %v", err) @@ -124,19 +138,13 @@ func (h *SubscriptionHandler) PostPortal(w http.ResponseWriter, r *http.Request) } user := middlewares.GetPrincipal(r) - if user.Email == "" { - http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "no subscription found with your e-mail address, please contact us!"), http.StatusFound) - return - } - - customer, err := h.findStripeCustomerByEmail(user.Email) - if err != nil { + if user.StripeCustomerId == "" { http.Redirect(w, r, fmt.Sprintf("%s/settings?error=%s#subscription", h.config.Server.BasePath, "no subscription found with your e-mail address, please contact us!"), http.StatusFound) return } portalParams := &stripe.BillingPortalSessionParams{ - Customer: &customer.ID, + Customer: &user.StripeCustomerId, ReturnURL: &h.config.Server.PublicUrl, } @@ -173,24 +181,68 @@ func (h *SubscriptionHandler) PostWebhook(w http.ResponseWriter, r *http.Request "customer.subscription.updated", "customer.subscription.created": // example payload: https://pastr.de/p/k7bx3alx38b1iawo6amtx09k - subscription, customer, err := h.parseSubscriptionEvent(w, r, event) + subscription, err := h.parseSubscriptionEvent(w, r, event) if err != nil { - return - } - logbuch.Info("received stripe subscription event of type '%s' for subscription '%s' (customer '%s' with email '%s').", event.Type, subscription.ID, customer.ID, customer.Email) - - user, err := h.userSrvc.GetUserByEmail(customer.Email) - if err != nil { - conf.Log().Request(r).Error("failed to find user with e-mail '%s' to update their subscription (status '%s')", subscription.Status) w.WriteHeader(http.StatusInternalServerError) return } + logbuch.Info("received stripe subscription event of type '%s' for subscription '%s' (customer '%s').", event.Type, subscription.ID, subscription.Customer.ID) + + // first, try to get user by associated customer id (requires checkout.session.completed event to have been processed before) + user, err := h.userSrvc.GetUserByStripeCustomerId(subscription.Customer.ID) + if err != nil { + conf.Log().Request(r).Warn("failed to find user with stripe customer id '%s' to update their subscription (status '%s')", subscription.Customer.ID, subscription.Status) + + // second, resolve customer and try to get user by email + customer, err := stripeCustomer.Get(subscription.Customer.ID, nil) + if err != nil { + conf.Log().Request(r).Error("failed to fetch stripe customer with id '%s', %v", subscription.Customer.ID, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + u, err := h.userSrvc.GetUserByEmail(customer.Email) + if err != nil { + conf.Log().Request(r).Error("failed to get user with email '%s' as stripe customer '%s' for processing event for subscription %s, %v", customer.Email, subscription.Customer.ID, subscription.ID, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + user = u + } if err := h.handleSubscriptionEvent(subscription, user); err != nil { conf.Log().Request(r).Error("failed to handle subscription event %s (%s) for user %s, %v", event.ID, event.Type, user.ID, err) w.WriteHeader(http.StatusInternalServerError) return } + + case "checkout.session.completed": + // example payload: https://pastr.de/p/d01iniw9naq9hkmvyqtxin2w + checkoutSession, err := h.parseCheckoutSessionEvent(w, r, event) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + logbuch.Info("received stripe checkout session event of type '%s' for session '%s' (customer '%s' with email '%s').", event.Type, checkoutSession.ID, checkoutSession.Customer.ID, checkoutSession.CustomerEmail) + + user, err := h.userSrvc.GetUserById(checkoutSession.ClientReferenceID) + if err != nil { + conf.Log().Request(r).Error("failed to find user with id '%s' to update associated stripe customer (%s)", user.ID, checkoutSession.Customer.ID) + w.WriteHeader(http.StatusInternalServerError) + return + } + + if user.StripeCustomerId == "" { + user.StripeCustomerId = checkoutSession.Customer.ID + if _, err := h.userSrvc.Update(user); err != nil { + conf.Log().Request(r).Error("failed to update stripe customer id (%s) for user '%s', %v", checkoutSession.Customer.ID, user.ID, err) + } else { + logbuch.Info("associated user '%s' with stripe customer '%s'", user.ID, checkoutSession.Customer.ID) + } + } else if user.StripeCustomerId != checkoutSession.Customer.ID { + conf.Log().Request(r).Error("invalid state: tried to associate user '%s' with stripe customer '%s', but '%s' already assigned", user.ID, checkoutSession.Customer.ID, user.StripeCustomerId) + } + default: logbuch.Warn("got stripe event '%s' with no handler defined", event.Type) } @@ -219,7 +271,7 @@ func (h *SubscriptionHandler) handleSubscriptionEvent(subscription *stripe.Subsc logbuch.Info("user %s got active subscription %s until %v", user.ID, subscription.ID, user.SubscribedUntil) } - if cancelAt := time.Unix(subscription.CancelAt, 0); !cancelAt.IsZero() { + if cancelAt := time.Unix(subscription.CancelAt, 0); !cancelAt.IsZero() && cancelAt.After(time.Now()) { logbuch.Info("user %s chose to cancel subscription %s by %v", user.ID, subscription.ID, cancelAt) } case "canceled", "unpaid": @@ -227,6 +279,7 @@ func (h *SubscriptionHandler) handleSubscriptionEvent(subscription *stripe.Subsc logbuch.Info("user %s's subscription %s got canceled, because of status update to '%s'", user.ID, subscription.ID, subscription.Status) default: logbuch.Info("got subscription (%s) status update to '%s' for user '%s'", subscription.ID, subscription.Status, user.ID) + return nil } _, err := h.userSrvc.Update(user) @@ -236,24 +289,25 @@ func (h *SubscriptionHandler) handleSubscriptionEvent(subscription *stripe.Subsc return err } -func (h *SubscriptionHandler) parseSubscriptionEvent(w http.ResponseWriter, r *http.Request, event stripe.Event) (*stripe.Subscription, *stripe.Customer, error) { +func (h *SubscriptionHandler) parseSubscriptionEvent(w http.ResponseWriter, r *http.Request, event stripe.Event) (*stripe.Subscription, error) { var subscription stripe.Subscription if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil { conf.Log().Request(r).Error("failed to parse stripe webhook payload: %v", err) w.WriteHeader(http.StatusBadRequest) - return nil, nil, err + return nil, err } + return &subscription, nil +} - customer, err := stripeCustomer.Get(subscription.Customer.ID, nil) - if err != nil { - conf.Log().Request(r).Error("failed to fetch stripe customer (%s): %v", subscription.Customer.ID, err) +func (h *SubscriptionHandler) parseCheckoutSessionEvent(w http.ResponseWriter, r *http.Request, event stripe.Event) (*stripe.CheckoutSession, error) { + var checkoutSession stripe.CheckoutSession + if err := json.Unmarshal(event.Data.Raw, &checkoutSession); err != nil { + conf.Log().Request(r).Error("failed to parse stripe webhook payload: %v", err) w.WriteHeader(http.StatusBadRequest) - return nil, nil, err + return nil, err } - logbuch.Info("associated stripe customer %s with user %s", customer.ID, customer.Email) - - return &subscription, customer, nil + return &checkoutSession, nil } func (h *SubscriptionHandler) findStripeCustomerByEmail(email string) (*stripe.Customer, error) { @@ -278,6 +332,6 @@ func (h *SubscriptionHandler) findStripeCustomerByEmail(email string) (*stripe.C func (h *SubscriptionHandler) clearSubscriptionNotificationStatus(userId string) { key := fmt.Sprintf("%s_%s", conf.KeySubscriptionNotificationSent, userId) if err := h.keyValueSrvc.DeleteString(key); err != nil { - conf.Log().Error("failed to delete '%s', %v", key, err) + logbuch.Warn("failed to delete '%s', %v", key, err) } } diff --git a/services/services.go b/services/services.go index d56c8f4..d0da981 100644 --- a/services/services.go +++ b/services/services.go @@ -125,6 +125,7 @@ type IUserService interface { GetUserByKey(string) (*models.User, error) GetUserByEmail(string) (*models.User, error) GetUserByResetToken(string) (*models.User, error) + GetUserByStripeCustomerId(string) (*models.User, error) GetAll() ([]*models.User, error) GetMany([]string) ([]*models.User, error) GetManyMapped([]string) (map[string]*models.User, error) diff --git a/services/user.go b/services/user.go index b8d4977..7c10e2f 100644 --- a/services/user.go +++ b/services/user.go @@ -62,7 +62,7 @@ func (srv *UserService) GetUserById(userId string) (*models.User, error) { return u.(*models.User), nil } - u, err := srv.repository.GetById(userId) + u, err := srv.repository.FindOne(models.User{ID: userId}) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func (srv *UserService) GetUserByKey(key string) (*models.User, error) { return u.(*models.User), nil } - u, err := srv.repository.GetByApiKey(key) + u, err := srv.repository.FindOne(models.User{ApiKey: key}) if err != nil { return nil, err } @@ -86,11 +86,15 @@ func (srv *UserService) GetUserByKey(key string) (*models.User, error) { } func (srv *UserService) GetUserByEmail(email string) (*models.User, error) { - return srv.repository.GetByEmail(email) + return srv.repository.FindOne(models.User{Email: email}) } func (srv *UserService) GetUserByResetToken(resetToken string) (*models.User, error) { - return srv.repository.GetByResetToken(resetToken) + return srv.repository.FindOne(models.User{ResetToken: resetToken}) +} + +func (srv *UserService) GetUserByStripeCustomerId(customerId string) (*models.User, error) { + return srv.repository.FindOne(models.User{StripeCustomerId: customerId}) } func (srv *UserService) GetAll() ([]*models.User, error) { diff --git a/views/settings.tpl.html b/views/settings.tpl.html index bd5598b..7bce4c1 100644 --- a/views/settings.tpl.html +++ b/views/settings.tpl.html @@ -81,9 +81,6 @@ Optional in general, but required for weekly reports and for resetting your password. - {{ if .User.HasActiveSubscription }} - You cannot unset or change your e-mail address while you have an active subscription. - {{ end }}
@@ -91,7 +88,6 @@ type="email" id="email" name="email" placeholder="Enter your e-mail address" value="{{ .User.Email }}" - {{ if .User.HasActiveSubscription }}disabled{{ end }} >