7.4 KB
payments.go
package payments
import (
	"fmt"
	"io"
	"log"
	"os"

	"github.com/readysite/readysite/readysite.org/models"
	"github.com/stripe/stripe-go/v82"
	portalsession "github.com/stripe/stripe-go/v82/billingportal/session"
	"github.com/stripe/stripe-go/v82/checkout/session"
	"github.com/stripe/stripe-go/v82/customer"
	"github.com/stripe/stripe-go/v82/webhook"
)
// Plan describes a billing tier that a site can be on.
type Plan struct {
	ID    string // machine identifier: "free", "hobby", or "pro"
	Name  string // display name
	Price int    // monthly price, in cents
}

// The billing tiers offered. Prices are expressed in cents per month.
var (
	FreePlan  = Plan{Name: "Free", ID: "free"}                   // $0.00
	HobbyPlan = Plan{Name: "Hobby", ID: "hobby", Price: 5 * 100} // $5.00
	ProPlan   = Plan{Name: "Pro", ID: "pro", Price: 20 * 100}    // $20.00
)
// Stripe settings, filled from the environment by init. An empty string
// means the corresponding variable was not set:
//   - stripeAPIKey:        STRIPE_API_KEY
//   - stripeWebhookSecret: STRIPE_WEBHOOK_SECRET
//   - hobbyPriceID:        STRIPE_HOBBY_PRICE_ID
//   - proPriceID:          STRIPE_PRO_PRICE_ID
var stripeAPIKey, stripeWebhookSecret, hobbyPriceID, proPriceID string
// init loads the Stripe configuration from the environment. When an API key
// is present it is also installed as the stripe-go global key.
func init() {
	stripeAPIKey = os.Getenv("STRIPE_API_KEY")
	stripeWebhookSecret = os.Getenv("STRIPE_WEBHOOK_SECRET")
	hobbyPriceID, proPriceID = os.Getenv("STRIPE_HOBBY_PRICE_ID"), os.Getenv("STRIPE_PRO_PRICE_ID")

	if stripeAPIKey == "" {
		// No key configured: leave stripe.Key untouched.
		return
	}
	stripe.Key = stripeAPIKey
}
// Enabled reports whether a Stripe API key was provided via STRIPE_API_KEY.
func Enabled() bool {
	return len(stripeAPIKey) != 0
}
// PriceIDForPlan maps a plan ID ("hobby" or "pro") to its configured Stripe
// Price ID. Any other plan, including "free", yields "". A recognized plan
// also yields "" if its price variable was not set in the environment.
func PriceIDForPlan(plan string) string {
	if plan == "hobby" {
		return hobbyPriceID
	}
	if plan == "pro" {
		return proPriceID
	}
	return ""
}
// PlanForPriceID maps a Stripe Price ID back to a plan ID ("hobby" or "pro").
// It returns "" when priceID is empty or matches no configured price.
//
// The empty-string guard matters. When STRIPE_HOBBY_PRICE_ID or
// STRIPE_PRO_PRICE_ID is unset, the configured ID is "". Without the guard,
// an empty priceID would then be reported as that plan.
func PlanForPriceID(priceID string) string {
	switch {
	case priceID == "":
		return ""
	case priceID == hobbyPriceID:
		return "hobby"
	case priceID == proPriceID:
		return "pro"
	default:
		return ""
	}
}
// EnsureCustomer returns the Stripe customer ID for user, creating the
// customer on first use.
//
// If user.StripeCustomerID is already set, it is returned without calling
// Stripe. Otherwise a customer is created with the user's email, name and a
// "user_id" metadata entry. Its ID is stored on user and persisted via
// models.Users.Update.
//
// If persisting fails, the error is returned. In that case
// user.StripeCustomerID has already been set in memory, and the new Stripe
// customer is not removed.
func EnsureCustomer(user *models.User) (string, error) {
	if existing := user.StripeCustomerID; existing != "" {
		return existing, nil
	}

	created, err := customer.New(&stripe.CustomerParams{
		Params: stripe.Params{
			Metadata: map[string]string{"user_id": user.ID},
		},
		Email: stripe.String(user.Email),
		Name:  stripe.String(user.Name),
	})
	if err != nil {
		return "", err
	}

	user.StripeCustomerID = created.ID
	if err = models.Users.Update(user); err != nil {
		return "", err
	}
	return created.ID, nil
}
// CreateCheckoutSession creates a subscription-mode Stripe Checkout session
// for upgrading a site and returns the hosted checkout URL.
//
// Parameters:
//   - customerID: the Stripe customer to bill (see EnsureCustomer).
//   - priceID: the Stripe Price to subscribe to (see PriceIDForPlan, which
//     returns "" for unknown or unconfigured plans).
//   - siteID: the site being upgraded.
//   - successURL, cancelURL: where Checkout redirects afterwards.
//
// siteID is stored as "site_id" metadata on both the session and the
// subscription it creates. The webhook handlers use that metadata to locate
// the site, so an empty siteID would let a customer pay without the site ever
// being upgraded. Empty IDs are therefore rejected before any API call.
func CreateCheckoutSession(customerID, priceID, siteID, successURL, cancelURL string) (string, error) {
	switch {
	case customerID == "":
		return "", fmt.Errorf("payments: checkout for site %q: empty Stripe customer ID", siteID)
	case priceID == "":
		return "", fmt.Errorf("payments: checkout for site %q: empty Stripe price ID", siteID)
	case siteID == "":
		return "", fmt.Errorf("payments: checkout with price %q: empty site ID", priceID)
	}

	params := &stripe.CheckoutSessionParams{
		Customer: stripe.String(customerID),
		Mode:     stripe.String(string(stripe.CheckoutSessionModeSubscription)),
		LineItems: []*stripe.CheckoutSessionLineItemParams{
			{
				Price:    stripe.String(priceID),
				Quantity: stripe.Int64(1),
			},
		},
		SuccessURL: stripe.String(successURL),
		CancelURL:  stripe.String(cancelURL),
		// Subscription events (updated/deleted) find the site via this metadata.
		SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{
			Metadata: map[string]string{
				"site_id": siteID,
			},
		},
		// checkout.session.completed finds the site via this metadata.
		Params: stripe.Params{
			Metadata: map[string]string{
				"site_id": siteID,
			},
		},
	}
	s, err := session.New(params)
	if err != nil {
		return "", err
	}
	return s.URL, nil
}
// CreatePortalSession creates a Stripe Customer Portal session for self-service billing.
// Returns the portal session URL.
func CreatePortalSession(customerID, returnURL string) (string, error) {
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(customerID),
ReturnURL: stripe.String(returnURL),
}
s, err := portalsession.New(params)
if err != nil {
return "", err
}
return s.URL, nil
}
// HandleCheckoutCompleted processes a checkout.session.completed event.
//
// It finds the site named by the session's "site_id" metadata and records the
// subscription ID. When the subscription's items are present on the event,
// it also stores the first item's price and the matching plan. Failures are
// logged, not returned.
//
// NOTE(review): in webhook payloads the session's subscription is normally an
// unexpanded ID, so Items is typically nil here and this handler does not
// change the plan. Presumably the subscription events apply it later; confirm
// against the webhook routing.
func HandleCheckoutCompleted(sess *stripe.CheckoutSession) {
	if sess == nil {
		log.Printf("[payments] checkout.session.completed: nil session")
		return
	}
	siteID := sess.Metadata["site_id"]
	if siteID == "" {
		log.Printf("[payments] checkout.session.completed: no site_id in metadata")
		return
	}
	site, err := models.Sites.Get(siteID)
	if err != nil || site == nil {
		log.Printf("[payments] checkout.session.completed: site %s not found: %v", siteID, err)
		return
	}

	if sub := sess.Subscription; sub != nil {
		site.StripeSubscriptionID = sub.ID
		// Items (and each item's Price) are pointers that are nil when the
		// subscription is not expanded. Reading through them unguarded panics.
		if sub.Items != nil && len(sub.Items.Data) > 0 {
			if item := sub.Items.Data[0]; item != nil && item.Price != nil {
				site.StripePriceID = item.Price.ID
				if plan := PlanForPriceID(item.Price.ID); plan != "" {
					site.Plan = plan
				}
			}
		}
	}

	if err := models.Sites.Update(site); err != nil {
		log.Printf("[payments] checkout.session.completed: failed to update site %s: %v", siteID, err)
		return
	}
	log.Printf("[payments] checkout.session.completed: site %s upgraded to %s", siteID, site.Plan)
}
// HandleSubscriptionChange processes a customer.subscription.updated event.
//
// It finds the site named by the subscription's "site_id" metadata (set by
// CreateCheckoutSession) and records the subscription ID. It then syncs the
// stored price and plan from the subscription's first item. An unrecognized
// price updates StripePriceID but leaves the plan unchanged. Failures are
// logged, not returned.
//
// NOTE(review): sub.Status is not consulted, so a status change alone (e.g.
// to past_due) does not alter the plan here. Confirm that is intended.
func HandleSubscriptionChange(sub *stripe.Subscription) {
	if sub == nil {
		log.Printf("[payments] subscription.updated: nil subscription")
		return
	}
	siteID := sub.Metadata["site_id"]
	if siteID == "" {
		log.Printf("[payments] subscription.updated: no site_id in metadata")
		return
	}
	site, err := models.Sites.Get(siteID)
	if err != nil || site == nil {
		log.Printf("[payments] subscription.updated: site %s not found: %v", siteID, err)
		return
	}

	// Items and each item's Price are pointers. Guard them so that a sparse
	// payload is skipped instead of panicking.
	if sub.Items != nil && len(sub.Items.Data) > 0 {
		if item := sub.Items.Data[0]; item != nil && item.Price != nil {
			site.StripePriceID = item.Price.ID
			if plan := PlanForPriceID(item.Price.ID); plan != "" {
				site.Plan = plan
			}
		}
	}
	site.StripeSubscriptionID = sub.ID

	if err := models.Sites.Update(site); err != nil {
		log.Printf("[payments] subscription.updated: failed to update site %s: %v", siteID, err)
		return
	}
	log.Printf("[payments] subscription.updated: site %s plan=%s", siteID, site.Plan)
}
// HandleSubscriptionDeleted processes a customer.subscription.deleted event.
//
// It downgrades the site named by the subscription's "site_id" metadata to
// the free plan and clears its stored subscription and price IDs. Failures
// are logged, not returned.
//
// A deletion is ignored if the site currently records a different
// subscription ID. In that case the site has moved to a newer subscription,
// and downgrading it would revoke a plan that is still being paid for.
func HandleSubscriptionDeleted(sub *stripe.Subscription) {
	if sub == nil {
		log.Printf("[payments] subscription.deleted: nil subscription")
		return
	}
	siteID := sub.Metadata["site_id"]
	if siteID == "" {
		log.Printf("[payments] subscription.deleted: no site_id in metadata")
		return
	}
	site, err := models.Sites.Get(siteID)
	if err != nil || site == nil {
		log.Printf("[payments] subscription.deleted: site %s not found: %v", siteID, err)
		return
	}

	if site.StripeSubscriptionID != "" && site.StripeSubscriptionID != sub.ID {
		log.Printf("[payments] subscription.deleted: site %s is on subscription %s, ignoring deletion of %s", siteID, site.StripeSubscriptionID, sub.ID)
		return
	}

	site.Plan = "free"
	site.StripeSubscriptionID = ""
	site.StripePriceID = ""
	if err := models.Sites.Update(site); err != nil {
		log.Printf("[payments] subscription.deleted: failed to update site %s: %v", siteID, err)
		return
	}
	log.Printf("[payments] subscription.deleted: site %s downgraded to free", siteID)
}
// ConstructWebhookEvent verifies the Stripe-Signature header value sig
// against payload using STRIPE_WEBHOOK_SECRET. On success it decodes the
// event.
//
// It refuses to run when the secret is unset. An HMAC computed with an empty
// key is trivially reproducible, so "verifying" with it would accept events
// forged by anyone.
func ConstructWebhookEvent(payload []byte, sig string) (stripe.Event, error) {
	if stripeWebhookSecret == "" {
		return stripe.Event{}, fmt.Errorf("payments: cannot verify webhook: %s is not set", "STRIPE_WEBHOOK_SECRET")
	}
	return webhook.ConstructEvent(payload, sig, stripeWebhookSecret)
}
// ConstructWebhookEventFromReader reads a webhook request body and passes it
// to ConstructWebhookEvent for verification and decoding.
//
// Bodies are capped at 64 KiB. A larger body is reported as an explicit error
// rather than being silently truncated, which would otherwise surface as a
// misleading signature-mismatch error.
func ConstructWebhookEventFromReader(body io.Reader, sig string) (stripe.Event, error) {
	const maxPayload = 65536
	// Read one byte past the cap so an oversized body can be detected.
	payload, err := io.ReadAll(io.LimitReader(body, maxPayload+1))
	if err != nil {
		return stripe.Event{}, err
	}
	if len(payload) > maxPayload {
		return stripe.Event{}, fmt.Errorf("payments: webhook payload exceeds %d bytes", maxPayload)
	}
	return ConstructWebhookEvent(payload, sig)
}
// --- Dev mode stubs (used when Stripe is not configured) ---

// UpgradeToHobby moves site to the Hobby plan and saves it. It only touches
// the database and never calls Stripe; it is a dev-mode stand-in for checkout.
func UpgradeToHobby(site *models.Site) error {
	const target = "hobby"
	site.Plan = target
	err := models.Sites.Update(site)
	return err
}
// UpgradeToPro moves site to the Pro plan and saves it. It only touches the
// database and never calls Stripe; it is a dev-mode stand-in for checkout.
func UpgradeToPro(site *models.Site) error {
	const target = "pro"
	site.Plan = target
	err := models.Sites.Update(site)
	return err
}