package routes import ( "encoding/json" "errors" "fmt" "github.com/emvi/logbuch" "github.com/go-chi/chi/v5" conf "github.com/muety/wakapi/config" "github.com/muety/wakapi/middlewares" "github.com/muety/wakapi/models" routeutils "github.com/muety/wakapi/routes/utils" "github.com/muety/wakapi/services" "github.com/stripe/stripe-go/v74" stripePortalSession "github.com/stripe/stripe-go/v74/billingportal/session" stripeCheckoutSession "github.com/stripe/stripe-go/v74/checkout/session" stripeCustomer "github.com/stripe/stripe-go/v74/customer" stripePrice "github.com/stripe/stripe-go/v74/price" "github.com/stripe/stripe-go/v74/webhook" "io/ioutil" "net/http" "strings" "time" ) /* How to integrate with Stripe? --- 1. Create a plan with recurring payment (https://dashboard.stripe.com/test/products?active=true), copy its ID and save it as 'standard_price_id' 2. Create a webhook (https://dashboard.stripe.com/test/webhooks), with target URL '/subscription/webhook' and events ['customer.subscription.created', 'customer.subscription.updated', 'customer.subscription.deleted', 'checkout.session.completed'], copy the endpoint secret and save it to 'stripe_endpoint_secret' 3. Create a secret API key (https://dashboard.stripe.com/test/apikeys), copy it and save it to 'stripe_secret_key' 4. Copy the publishable API key (https://dashboard.stripe.com/test/apikeys) and save it to 'stripe_api_key' */ type SubscriptionHandler struct { config *conf.Config userSrvc services.IUserService mailSrvc services.IMailService keyValueSrvc services.IKeyValueService httpClient *http.Client } func NewSubscriptionHandler( userService services.IUserService, mailService services.IMailService, keyValueService services.IKeyValueService, ) *SubscriptionHandler { config := conf.Get() if config.Subscriptions.Enabled { stripe.Key = config.Subscriptions.StripeSecretKey price, err := stripePrice.Get(config.Subscriptions.StandardPriceId, nil) if err != nil { logbuch.Fatal("failed to fetch stripe plan details: %v", err) } config.Subscriptions.StandardPrice = strings.TrimSpace(fmt.Sprintf("%2.f €", price.UnitAmountDecimal/100.0)) // TODO: respect actual currency logbuch.Info("enabling subscriptions with stripe payment for %s / month", config.Subscriptions.StandardPrice) } return &SubscriptionHandler{ config: config, userSrvc: userService, mailSrvc: mailService, keyValueSrvc: keyValueService, httpClient: &http.Client{Timeout: 10 * time.Second}, } } // https://stripe.com/docs/billing/quickstart?lang=go func (h *SubscriptionHandler) RegisterRoutes(router chi.Router) { if !h.config.Subscriptions.Enabled { return } subRouterPublic := chi.NewRouter() subRouterPublic.Get("/success", h.GetCheckoutSuccess) subRouterPublic.Get("/cancel", h.GetCheckoutCancel) subRouterPublic.Post("/webhook", h.PostWebhook) subRouterPrivate := chi.NewRouter() subRouterPrivate.Use( middlewares.NewAuthenticateMiddleware(h.userSrvc). WithRedirectTarget(defaultErrorRedirectTarget()). WithRedirectErrorMessage("unauthorized").Handler, ) subRouterPrivate.Post("/checkout", h.PostCheckout) subRouterPrivate.Post("/portal", h.PostPortal) subRouterPublic.Mount("/", subRouterPrivate) router.Mount("/subscription", subRouterPublic) } func (h *SubscriptionHandler) PostCheckout(w http.ResponseWriter, r *http.Request) { if h.config.IsDev() { loadTemplates() } user := middlewares.GetPrincipal(r) if user.Email == "" { routeutils.SetError(r, w, "missing e-mail address") http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) return } if err := r.ParseForm(); err != nil { routeutils.SetError(r, w, "missing form values") http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) return } checkoutParams := &stripe.CheckoutSessionParams{ Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), LineItems: []*stripe.CheckoutSessionLineItemParams{ { Price: &h.config.Subscriptions.StandardPriceId, Quantity: stripe.Int64(1), }, }, ClientReferenceID: &user.ID, SuccessURL: stripe.String(fmt.Sprintf("%s%s/subscription/success", h.config.Server.PublicUrl, h.config.Server.BasePath)), CancelURL: stripe.String(fmt.Sprintf("%s%s/subscription/cancel", h.config.Server.PublicUrl, h.config.Server.BasePath)), } if user.StripeCustomerId != "" { checkoutParams.Customer = &user.StripeCustomerId } else { checkoutParams.CustomerEmail = &user.Email } session, err := stripeCheckoutSession.New(checkoutParams) if err != nil { conf.Log().Request(r).Error("failed to create stripe checkout session: %v", err) routeutils.SetError(r, w, "something went wrong") http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) return } http.Redirect(w, r, session.URL, http.StatusSeeOther) } func (h *SubscriptionHandler) PostPortal(w http.ResponseWriter, r *http.Request) { if h.config.IsDev() { loadTemplates() } user := middlewares.GetPrincipal(r) if user.StripeCustomerId == "" { routeutils.SetError(r, w, "no subscription found with your e-mail address, please contact us!") http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) return } portalParams := &stripe.BillingPortalSessionParams{ Customer: &user.StripeCustomerId, ReturnURL: &h.config.Server.PublicUrl, } session, err := stripePortalSession.New(portalParams) if err != nil { conf.Log().Request(r).Error("failed to create stripe portal session: %v", err) routeutils.SetError(r, w, "something went wrong") http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) return } http.Redirect(w, r, session.URL, http.StatusSeeOther) } func (h *SubscriptionHandler) PostWebhook(w http.ResponseWriter, r *http.Request) { bodyReader := http.MaxBytesReader(w, r.Body, int64(65536)) payload, err := ioutil.ReadAll(bodyReader) if err != nil { conf.Log().Request(r).Error("error in stripe webhook request: %v", err) w.WriteHeader(http.StatusServiceUnavailable) return } event, err := webhook.ConstructEventWithOptions(payload, r.Header.Get("Stripe-Signature"), h.config.Subscriptions.StripeEndpointSecret, webhook.ConstructEventOptions{ IgnoreAPIVersionMismatch: true, }) if err != nil { conf.Log().Request(r).Error("stripe webhook signature verification failed: %v", err) w.WriteHeader(http.StatusBadRequest) return } switch event.Type { case "customer.subscription.deleted", "customer.subscription.updated", "customer.subscription.created": // example payload: https://pastr.de/p/k7bx3alx38b1iawo6amtx09k subscription, err := h.parseSubscriptionEvent(w, r, event) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } logbuch.Info("received stripe subscription event of type '%s' for subscription '%s' (customer '%s').", event.Type, subscription.ID, subscription.Customer.ID) // first, try to get user by associated customer id (requires checkout.session.completed event to have been processed before) user, err := h.userSrvc.GetUserByStripeCustomerId(subscription.Customer.ID) if err != nil { conf.Log().Request(r).Warn("failed to find user with stripe customer id '%s' to update their subscription (status '%s')", subscription.Customer.ID, subscription.Status) // second, resolve customer and try to get user by email customer, err := stripeCustomer.Get(subscription.Customer.ID, nil) if err != nil { conf.Log().Request(r).Error("failed to fetch stripe customer with id '%s', %v", subscription.Customer.ID, err) w.WriteHeader(http.StatusInternalServerError) return } u, err := h.userSrvc.GetUserByEmail(customer.Email) if err != nil { conf.Log().Request(r).Error("failed to get user with email '%s' as stripe customer '%s' for processing event for subscription %s, %v", customer.Email, subscription.Customer.ID, subscription.ID, err) w.WriteHeader(http.StatusInternalServerError) return } user = u } if err := h.handleSubscriptionEvent(subscription, user); err != nil { conf.Log().Request(r).Error("failed to handle subscription event %s (%s) for user %s, %v", event.ID, event.Type, user.ID, err) w.WriteHeader(http.StatusInternalServerError) return } case "checkout.session.completed": // example payload: https://pastr.de/p/d01iniw9naq9hkmvyqtxin2w checkoutSession, err := h.parseCheckoutSessionEvent(w, r, event) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } logbuch.Info("received stripe checkout session event of type '%s' for session '%s' (customer '%s' with email '%s').", event.Type, checkoutSession.ID, checkoutSession.Customer.ID, checkoutSession.CustomerEmail) user, err := h.userSrvc.GetUserById(checkoutSession.ClientReferenceID) if err != nil { conf.Log().Request(r).Error("failed to find user with id '%s' to update associated stripe customer (%s)", user.ID, checkoutSession.Customer.ID) w.WriteHeader(http.StatusInternalServerError) return } if user.StripeCustomerId == "" { user.StripeCustomerId = checkoutSession.Customer.ID if _, err := h.userSrvc.Update(user); err != nil { conf.Log().Request(r).Error("failed to update stripe customer id (%s) for user '%s', %v", checkoutSession.Customer.ID, user.ID, err) } else { logbuch.Info("associated user '%s' with stripe customer '%s'", user.ID, checkoutSession.Customer.ID) } } else if user.StripeCustomerId != checkoutSession.Customer.ID { conf.Log().Request(r).Error("invalid state: tried to associate user '%s' with stripe customer '%s', but '%s' already assigned", user.ID, checkoutSession.Customer.ID, user.StripeCustomerId) } default: logbuch.Warn("got stripe event '%s' with no handler defined", event.Type) } w.WriteHeader(http.StatusOK) } func (h *SubscriptionHandler) GetCheckoutSuccess(w http.ResponseWriter, r *http.Request) { routeutils.SetSuccess(r, w, "you have successfully subscribed to Wakapi!") http.Redirect(w, r, fmt.Sprintf("%s/settings", h.config.Server.BasePath), http.StatusFound) } func (h *SubscriptionHandler) GetCheckoutCancel(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, fmt.Sprintf("%s/settings#subscription", h.config.Server.BasePath), http.StatusFound) } func (h *SubscriptionHandler) handleSubscriptionEvent(subscription *stripe.Subscription, user *models.User) error { var hasSubscribed bool switch subscription.Status { case "active": until := models.CustomTime(time.Unix(subscription.CurrentPeriodEnd, 0)) if user.SubscribedUntil == nil || !user.SubscribedUntil.T().Equal(until.T()) { hasSubscribed = true user.SubscribedUntil = &until user.SubscriptionRenewal = &until logbuch.Info("user %s got active subscription %s until %v", user.ID, subscription.ID, user.SubscribedUntil) } if cancelAt := time.Unix(subscription.CancelAt, 0); !cancelAt.IsZero() && cancelAt.After(time.Now()) { user.SubscriptionRenewal = nil logbuch.Info("user %s chose to cancel subscription %s by %v", user.ID, subscription.ID, cancelAt) } case "canceled", "unpaid", "incomplete_expired": user.SubscribedUntil = nil user.SubscriptionRenewal = nil logbuch.Info("user %s's subscription %s got canceled, because of status update to '%s'", user.ID, subscription.ID, subscription.Status) default: logbuch.Info("got subscription (%s) status update to '%s' for user '%s'", subscription.ID, subscription.Status, user.ID) return nil } _, err := h.userSrvc.Update(user) if err == nil && hasSubscribed { go h.clearSubscriptionNotificationStatus(user.ID) } return err } func (h *SubscriptionHandler) parseSubscriptionEvent(w http.ResponseWriter, r *http.Request, event stripe.Event) (*stripe.Subscription, error) { var subscription stripe.Subscription if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil { conf.Log().Request(r).Error("failed to parse stripe webhook payload: %v", err) w.WriteHeader(http.StatusBadRequest) return nil, err } return &subscription, nil } func (h *SubscriptionHandler) parseCheckoutSessionEvent(w http.ResponseWriter, r *http.Request, event stripe.Event) (*stripe.CheckoutSession, error) { var checkoutSession stripe.CheckoutSession if err := json.Unmarshal(event.Data.Raw, &checkoutSession); err != nil { conf.Log().Request(r).Error("failed to parse stripe webhook payload: %v", err) w.WriteHeader(http.StatusBadRequest) return nil, err } return &checkoutSession, nil } func (h *SubscriptionHandler) findStripeCustomerByEmail(email string) (*stripe.Customer, error) { params := &stripe.CustomerSearchParams{ SearchParams: stripe.SearchParams{ Query: fmt.Sprintf(`email:"%s"`, email), }, } results := stripeCustomer.Search(params) if err := results.Err(); err != nil { return nil, err } if results.Next() { return results.Customer(), nil } else { return nil, errors.New("no customer found with given criteria") } } func (h *SubscriptionHandler) clearSubscriptionNotificationStatus(userId string) { key := fmt.Sprintf("%s_%s", conf.KeySubscriptionNotificationSent, userId) if err := h.keyValueSrvc.DeleteString(key); err != nil { logbuch.Warn("failed to delete '%s', %v", key, err) } }