7.4 KB
payments.go
package payments

import (
	"errors"
	"io"
	"log"
	"os"

	"github.com/readysite/readysite/readysite.org/models"
	"github.com/stripe/stripe-go/v82"
	portalsession "github.com/stripe/stripe-go/v82/billingportal/session"
	"github.com/stripe/stripe-go/v82/checkout/session"
	"github.com/stripe/stripe-go/v82/customer"
	"github.com/stripe/stripe-go/v82/webhook"
)

// Plan describes one billing tier that a site can be on.
type Plan struct {
	// ID is the machine-readable plan key: "free", "hobby", or "pro".
	ID string
	// Name is the human-readable label for the plan.
	Name string
	// Price is the monthly cost, expressed in cents.
	Price int
}

// FreePlan is the no-cost tier.
var FreePlan = Plan{ID: "free", Name: "Free", Price: 0}

// HobbyPlan costs $5.00 per month (500 cents).
var HobbyPlan = Plan{ID: "hobby", Name: "Hobby", Price: 500}

// ProPlan costs $20.00 per month (2000 cents).
var ProPlan = Plan{ID: "pro", Name: "Pro", Price: 2000}

// Stripe configuration, populated from environment variables in init.
var (
	// stripeAPIKey comes from STRIPE_API_KEY; an empty value means Stripe is
	// disabled (see Enabled).
	stripeAPIKey string
	// stripeWebhookSecret comes from STRIPE_WEBHOOK_SECRET.
	stripeWebhookSecret string
	// hobbyPriceID comes from STRIPE_HOBBY_PRICE_ID.
	hobbyPriceID string
	// proPriceID comes from STRIPE_PRO_PRICE_ID.
	proPriceID string
)

func init() {
	stripeAPIKey = os.Getenv("STRIPE_API_KEY")
	stripeWebhookSecret = os.Getenv("STRIPE_WEBHOOK_SECRET")
	hobbyPriceID = os.Getenv("STRIPE_HOBBY_PRICE_ID")
	proPriceID = os.Getenv("STRIPE_PRO_PRICE_ID")

	if stripeAPIKey != "" {
		stripe.Key = stripeAPIKey
	}
}

// Enabled reports whether a Stripe API key was configured (STRIPE_API_KEY is
// non-empty). The dev-mode upgrade stubs in this package exist for the case
// where it is not.
func Enabled() bool {
	configured := len(stripeAPIKey) > 0
	return configured
}

// PriceIDForPlan maps a plan ID to the Stripe Price ID configured for it.
// Only "hobby" and "pro" have prices; any other plan (including "free")
// yields "". The result is also "" if the corresponding environment
// variable was not set.
func PriceIDForPlan(plan string) string {
	if plan == "hobby" {
		return hobbyPriceID
	}
	if plan == "pro" {
		return proPriceID
	}
	return ""
}

// PlanForPriceID maps a Stripe Price ID back to its plan ID ("hobby" or
// "pro"). It returns "" for an unknown or empty price ID.
//
// The empty-ID guard matters when a price environment variable is unset:
// hobbyPriceID/proPriceID are then "", and comparing against them directly
// would map an empty price ID onto a paid plan.
func PlanForPriceID(priceID string) string {
	switch {
	case priceID == "":
		return ""
	case priceID == hobbyPriceID:
		return "hobby"
	case priceID == proPriceID:
		return "pro"
	default:
		return ""
	}
}

// EnsureCustomer returns the Stripe customer ID for user, creating the
// customer on first use.
//
// If user.StripeCustomerID is already set, it is returned as-is and Stripe is
// not contacted. Otherwise a customer is created with the user's email, name
// and a "user_id" metadata entry. The new ID is then stored on user and
// persisted with models.Users.Update.
//
// user.StripeCustomerID is assigned before the save. If the save fails, the
// struct therefore still carries the new ID, while the function returns ""
// and the error.
func EnsureCustomer(user *models.User) (string, error) {
	if existing := user.StripeCustomerID; existing != "" {
		return existing, nil
	}

	created, err := customer.New(&stripe.CustomerParams{
		Params: stripe.Params{
			Metadata: map[string]string{"user_id": user.ID},
		},
		Email: stripe.String(user.Email),
		Name:  stripe.String(user.Name),
	})
	if err != nil {
		return "", err
	}

	user.StripeCustomerID = created.ID
	if err = models.Users.Update(user); err != nil {
		return "", err
	}
	return created.ID, nil
}

// CreateCheckoutSession starts a subscription-mode Stripe Checkout session.
// The session bills customerID for one unit of priceID and redirects to
// successURL or cancelURL. It returns the session URL.
//
// siteID is attached as "site_id" metadata both to the Checkout Session and to
// the subscription data. The webhook handlers in this package read "site_id"
// back from the session and subscription metadata.
func CreateCheckoutSession(customerID, priceID, siteID, successURL, cancelURL string) (string, error) {
	// Build a fresh map per use so the two metadata fields never share storage.
	siteMetadata := func() map[string]string {
		return map[string]string{"site_id": siteID}
	}

	lineItem := &stripe.CheckoutSessionLineItemParams{
		Price:    stripe.String(priceID),
		Quantity: stripe.Int64(1),
	}

	params := &stripe.CheckoutSessionParams{
		Params:     stripe.Params{Metadata: siteMetadata()},
		Customer:   stripe.String(customerID),
		Mode:       stripe.String(string(stripe.CheckoutSessionModeSubscription)),
		LineItems:  []*stripe.CheckoutSessionLineItemParams{lineItem},
		SuccessURL: stripe.String(successURL),
		CancelURL:  stripe.String(cancelURL),
		SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{
			Metadata: siteMetadata(),
		},
	}

	checkout, err := session.New(params)
	if err != nil {
		return "", err
	}
	return checkout.URL, nil
}

// CreatePortalSession creates a Stripe Customer Portal session for self-service billing.
// Returns the portal session URL.
func CreatePortalSession(customerID, returnURL string) (string, error) {
	params := &stripe.BillingPortalSessionParams{
		Customer:  stripe.String(customerID),
		ReturnURL: stripe.String(returnURL),
	}
	s, err := portalsession.New(params)
	if err != nil {
		return "", err
	}
	return s.URL, nil
}

// HandleCheckoutCompleted processes a checkout.session.completed event.
//
// It looks up the site named by the session's "site_id" metadata and records
// the subscription ID. When the subscription's items are available, it also
// stores the price ID and the matching plan. Failures are logged, not
// returned, and the success line is logged only after the site is saved.
//
// In webhook payloads the session's subscription is normally not expanded. It
// then arrives with only an ID and a nil Items list, so every pointer on the
// path to the price is checked before use. Otherwise such payloads would
// panic.
func HandleCheckoutCompleted(sess *stripe.CheckoutSession) {
	siteID := sess.Metadata["site_id"]
	if siteID == "" {
		log.Printf("[payments] checkout.session.completed: no site_id in metadata")
		return
	}

	site, err := models.Sites.Get(siteID)
	if err != nil || site == nil {
		log.Printf("[payments] checkout.session.completed: site %s not found: %v", siteID, err)
		return
	}

	if sub := sess.Subscription; sub != nil {
		site.StripeSubscriptionID = sub.ID

		// Determine the plan from the subscription's first item, when expanded.
		// Sessions are created with a single line item (CreateCheckoutSession).
		if sub.Items != nil && len(sub.Items.Data) > 0 {
			if item := sub.Items.Data[0]; item != nil && item.Price != nil {
				site.StripePriceID = item.Price.ID
				if plan := PlanForPriceID(item.Price.ID); plan != "" {
					site.Plan = plan
				}
			}
		}
	}

	if err := models.Sites.Update(site); err != nil {
		log.Printf("[payments] checkout.session.completed: failed to update site %s: %v", siteID, err)
		return
	}
	log.Printf("[payments] checkout.session.completed: site %s upgraded to %s", siteID, site.Plan)
}

// HandleSubscriptionChange processes a customer.subscription.updated event.
//
// It looks up the site named by the subscription's "site_id" metadata and
// records the subscription ID. It also syncs the site's price ID and plan
// from the first subscription item. A price ID that maps to no known plan
// leaves site.Plan unchanged. Failures are logged, not returned, and the
// success line is logged only after the site is saved.
func HandleSubscriptionChange(sub *stripe.Subscription) {
	siteID := sub.Metadata["site_id"]
	if siteID == "" {
		log.Printf("[payments] subscription.updated: no site_id in metadata")
		return
	}

	site, err := models.Sites.Get(siteID)
	if err != nil || site == nil {
		log.Printf("[payments] subscription.updated: site %s not found: %v", siteID, err)
		return
	}

	// Items, its elements and Price are pointers in stripe-go. Check each one
	// so a partially populated subscription is skipped instead of panicking.
	if sub.Items != nil && len(sub.Items.Data) > 0 {
		if item := sub.Items.Data[0]; item != nil && item.Price != nil {
			site.StripePriceID = item.Price.ID
			if plan := PlanForPriceID(item.Price.ID); plan != "" {
				site.Plan = plan
			}
		}
	}

	site.StripeSubscriptionID = sub.ID
	if err := models.Sites.Update(site); err != nil {
		log.Printf("[payments] subscription.updated: failed to update site %s: %v", siteID, err)
		return
	}
	log.Printf("[payments] subscription.updated: site %s plan=%s", siteID, site.Plan)
}

// HandleSubscriptionDeleted processes a customer.subscription.deleted event.
//
// It looks up the site named by the subscription's "site_id" metadata. It then
// moves the site back to the "free" plan and clears its stored subscription
// and price IDs. Failures are logged, not returned. The "downgraded" line is
// logged only once the site has actually been saved.
func HandleSubscriptionDeleted(sub *stripe.Subscription) {
	siteID := sub.Metadata["site_id"]
	if siteID == "" {
		log.Printf("[payments] subscription.deleted: no site_id in metadata")
		return
	}

	site, err := models.Sites.Get(siteID)
	if err != nil || site == nil {
		log.Printf("[payments] subscription.deleted: site %s not found: %v", siteID, err)
		return
	}

	site.Plan = "free"
	site.StripeSubscriptionID = ""
	site.StripePriceID = ""
	if err := models.Sites.Update(site); err != nil {
		log.Printf("[payments] subscription.deleted: failed to update site %s: %v", siteID, err)
		return
	}
	log.Printf("[payments] subscription.deleted: site %s downgraded to free", siteID)
}

// ErrWebhookSecretNotConfigured is returned by ConstructWebhookEvent when
// STRIPE_WEBHOOK_SECRET was not set.
var ErrWebhookSecretNotConfigured = errors.New("payments: STRIPE_WEBHOOK_SECRET is not configured")

// ConstructWebhookEvent checks the Stripe-Signature header value sig against
// payload using the configured webhook secret, then parses the event.
//
// It refuses to verify anything when no secret is configured. Signatures are
// HMACs keyed by the secret, and an HMAC keyed with the empty string can be
// computed by anyone. Accepting it would let arbitrary callers forge events,
// for example a checkout.session.completed that upgrades a site for free.
func ConstructWebhookEvent(payload []byte, sig string) (stripe.Event, error) {
	if stripeWebhookSecret == "" {
		return stripe.Event{}, ErrWebhookSecretNotConfigured
	}
	return webhook.ConstructEvent(payload, sig, stripeWebhookSecret)
}

// maxWebhookPayloadBytes caps how much of a webhook request body is read.
const maxWebhookPayloadBytes = 65536

// ErrWebhookPayloadTooLarge is returned by ConstructWebhookEventFromReader
// when the body exceeds maxWebhookPayloadBytes.
var ErrWebhookPayloadTooLarge = errors.New("payments: webhook payload too large")

// ConstructWebhookEventFromReader reads a webhook request body of at most
// maxWebhookPayloadBytes and verifies it with ConstructWebhookEvent.
//
// A larger body is rejected with ErrWebhookPayloadTooLarge. Previously it was
// silently truncated, and a truncated payload can never match its signature,
// so the error showed up as a misleading verification failure.
func ConstructWebhookEventFromReader(body io.Reader, sig string) (stripe.Event, error) {
	// Read one byte past the cap so an oversized body is detectable.
	payload, err := io.ReadAll(io.LimitReader(body, maxWebhookPayloadBytes+1))
	if err != nil {
		return stripe.Event{}, err
	}
	if len(payload) > maxWebhookPayloadBytes {
		return stripe.Event{}, ErrWebhookPayloadTooLarge
	}
	return ConstructWebhookEvent(payload, sig)
}

// --- Dev mode stubs (used when Stripe is not configured) ---

// UpgradeToHobby sets the site's plan to "hobby" and saves it. It is a
// database-only stub for dev mode and makes no Stripe calls.
func UpgradeToHobby(site *models.Site) error {
	return setPlanLocally(site, "hobby")
}

// UpgradeToPro sets the site's plan to "pro" and saves it. It is a
// database-only stub for dev mode and makes no Stripe calls.
func UpgradeToPro(site *models.Site) error {
	return setPlanLocally(site, "pro")
}

// setPlanLocally assigns plan to site and persists it via models.Sites.Update.
func setPlanLocally(site *models.Site, plan string) error {
	site.Plan = plan
	return models.Sites.Update(site)
}
← Back