6.0 KB
billing.go
package controllers

import (
	"encoding/json"
	"log"
	"net/http"

	"github.com/readysite/readysite/readysite.org/internal/access"
	"github.com/readysite/readysite/readysite.org/internal/payments"
	"github.com/readysite/readysite/readysite.org/models"
	"github.com/readysite/readysite/pkg/application"
	"github.com/stripe/stripe-go/v82"
)

// Billing constructs a fresh billing controller and returns it together with
// "billing", the name it is registered under.
func Billing() (string, *BillingController) {
	controller := new(BillingController)
	return "billing", controller
}

// BillingController handles the billing and plan page.
//
// It serves GET /billing plus the checkout, customer-portal and Stripe
// webhook endpoints registered in Setup. Per-request copies are produced by
// Handle, which sets the Request field on the copy.
type BillingController struct {
	// BaseController supplies the shared controller plumbing, including the
	// Request field that Handle assigns.
	application.BaseController
}

// Setup initialises the embedded base controller and registers the billing
// routes on the default ServeMux:
//
//   - GET  /billing           renders billing.html behind RequireAuth
//   - POST /billing/checkout  CreateCheckout behind RequireAuthAPI
//   - POST /billing/portal    OpenPortal behind RequireAuth
//   - POST /billing/webhook   HandleWebhook with no auth middleware; the
//     handler verifies the Stripe-Signature header itself
func (c *BillingController) Setup(app *application.App) {
	c.BaseController.Setup(app)

	mux := http.DefaultServeMux
	mux.Handle("GET /billing", app.Serve("billing.html", RequireAuth))
	mux.Handle("POST /billing/checkout", app.Method(c, "CreateCheckout", RequireAuthAPI))
	mux.Handle("POST /billing/portal", app.Method(c, "OpenPortal", RequireAuth))
	mux.Handle("POST /billing/webhook", app.Method(c, "HandleWebhook", nil))
}

// Handle returns a request-scoped controller instance. Because the receiver
// is a value, c is already a private copy of the registered controller, so
// attaching r to it never leaks into other requests.
func (c BillingController) Handle(r *http.Request) application.Controller {
	scoped := c
	scoped.Request = r
	return &scoped
}

// UserSites returns the current user's non-deleted sites, newest first.
//
// It returns nil when no user can be resolved from the request's JWT or when
// the query fails. Query failures are logged so that a database problem is
// not silently indistinguishable from "this user has no sites".
func (c *BillingController) UserSites() []*models.Site {
	user := access.GetUserFromJWT(c.Request)
	if user == nil {
		return nil
	}
	sites, err := models.Sites.Search("WHERE UserID = ? AND Status != 'deleted' ORDER BY CreatedAt DESC", user.ID)
	if err != nil {
		log.Printf("[billing] Failed to load sites for user %s: %v", user.ID, err)
		return nil
	}
	return sites
}

// StripeEnabled reports whether Stripe payments are configured, delegating
// the decision entirely to payments.Enabled.
func (c *BillingController) StripeEnabled() bool {
	enabled := payments.Enabled()
	return enabled
}

// HasStripeCustomer reports whether the current user already has a Stripe
// customer ID on record. It is false when no user can be resolved from the
// request's JWT.
func (c *BillingController) HasStripeCustomer() bool {
	if u := access.GetUserFromJWT(c.Request); u != nil {
		return u.StripeCustomerID != ""
	}
	return false
}

// CreateCheckout creates a Stripe Checkout session for one of the current
// user's sites and returns its URL as JSON: {"checkout_url": "..."}.
//
// Form parameters:
//   - site_id: the site being upgraded; it must belong to the current user.
//   - plan:    plan name, mapped to a Stripe price by payments.PriceIDForPlan.
//
// Errors are written with jsonError:
//   - 503 when payments are not configured
//   - 401 when no user can be resolved from the JWT
//   - 400 for missing parameters or an unknown plan
//   - 404 when the site cannot be loaded or is owned by someone else
//   - 500 when creating the Stripe customer or session fails
func (c *BillingController) CreateCheckout(w http.ResponseWriter, r *http.Request) {
	if !payments.Enabled() {
		jsonError(w, "Payments not configured", http.StatusServiceUnavailable)
		return
	}

	// The route is behind RequireAuthAPI, but GetUserFromJWT can still yield
	// nil (the other methods on this controller check for it). Reject instead
	// of panicking on user.ID below.
	user := access.GetUserFromJWT(r)
	if user == nil {
		jsonError(w, "Not authenticated", http.StatusUnauthorized)
		return
	}

	siteID := r.FormValue("site_id")
	plan := r.FormValue("plan")
	if siteID == "" || plan == "" {
		jsonError(w, "Missing site_id or plan", http.StatusBadRequest)
		return
	}

	// Verify site ownership. A lookup error and a foreign site both produce
	// 404 so callers cannot probe for other users' site IDs.
	site, err := models.Sites.Get(siteID)
	if err != nil || site == nil || site.UserID != user.ID {
		jsonError(w, "Site not found", http.StatusNotFound)
		return
	}

	priceID := payments.PriceIDForPlan(plan)
	if priceID == "" {
		jsonError(w, "Invalid plan", http.StatusBadRequest)
		return
	}

	customerID, err := payments.EnsureCustomer(user)
	if err != nil {
		log.Printf("[billing] EnsureCustomer failed for user %s: %v", user.ID, err)
		jsonError(w, "Failed to set up billing", http.StatusInternalServerError)
		return
	}

	// Build absolute return URLs from the incoming request: https when the
	// connection is TLS or a proxy reports X-Forwarded-Proto: https.
	scheme := "https"
	if r.TLS == nil && r.Header.Get("X-Forwarded-Proto") != "https" {
		scheme = "http"
	}
	baseURL := scheme + "://" + r.Host
	successURL := baseURL + "/billing?success=true"
	cancelURL := baseURL + "/billing?canceled=true"

	checkoutURL, err := payments.CreateCheckoutSession(customerID, priceID, siteID, successURL, cancelURL)
	if err != nil {
		log.Printf("[billing] CreateCheckoutSession failed for site %s: %v", siteID, err)
		jsonError(w, "Failed to create checkout session", http.StatusInternalServerError)
		return
	}

	jsonResponse(w, map[string]string{"checkout_url": checkoutURL})
}

// OpenPortal creates a Stripe Customer Portal session for the current user
// and redirects (303 See Other) to it.
//
// It redirects back to /billing instead in these cases:
//   - payments are not configured
//   - no user can be resolved from the JWT
//   - the user has no Stripe customer yet
//   - creating the portal session fails
func (c *BillingController) OpenPortal(w http.ResponseWriter, r *http.Request) {
	if !payments.Enabled() {
		http.Redirect(w, r, "/billing", http.StatusSeeOther)
		return
	}

	// GetUserFromJWT may return nil even behind RequireAuth; check before
	// dereferencing to avoid a panic.
	user := access.GetUserFromJWT(r)
	if user == nil || user.StripeCustomerID == "" {
		http.Redirect(w, r, "/billing", http.StatusSeeOther)
		return
	}

	// Return to /billing on the same host: https when the connection is TLS
	// or a proxy reports X-Forwarded-Proto: https.
	scheme := "https"
	if r.TLS == nil && r.Header.Get("X-Forwarded-Proto") != "https" {
		scheme = "http"
	}
	returnURL := scheme + "://" + r.Host + "/billing"

	portalURL, err := payments.CreatePortalSession(user.StripeCustomerID, returnURL)
	if err != nil {
		log.Printf("[billing] CreatePortalSession failed for user %s: %v", user.ID, err)
		http.Redirect(w, r, "/billing", http.StatusSeeOther)
		return
	}

	http.Redirect(w, r, portalURL, http.StatusSeeOther)
}

// HandleWebhook processes Stripe webhook events.
//
// The request must carry a Stripe-Signature header. The body is verified by
// payments.ConstructWebhookEventFromReader before any event is acted on.
// Responses:
//   - 400 when the signature is missing or invalid, or when a handled
//     event's payload cannot be decoded
//   - 200 otherwise, including for event types that are only logged
func (c *BillingController) HandleWebhook(w http.ResponseWriter, r *http.Request) {
	// This endpoint has no auth middleware, so bound the body before reading
	// it. Otherwise any client could send an arbitrarily large payload that
	// must be read before the signature check can reject it. Stripe's own
	// example uses 64 KiB; 1 MiB leaves headroom for large event objects.
	const maxBodyBytes = 1 << 20
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)

	sig := r.Header.Get("Stripe-Signature")
	if sig == "" {
		http.Error(w, "Missing signature", http.StatusBadRequest)
		return
	}

	event, err := payments.ConstructWebhookEventFromReader(r.Body, sig)
	if err != nil {
		log.Printf("[billing] Webhook signature verification failed: %v", err)
		http.Error(w, "Invalid signature", http.StatusBadRequest)
		return
	}

	switch event.Type {
	case "checkout.session.completed":
		var sess stripe.CheckoutSession
		if err := json.Unmarshal(event.Data.Raw, &sess); err != nil {
			log.Printf("[billing] Failed to parse checkout.session.completed: %v", err)
			http.Error(w, "Parse error", http.StatusBadRequest)
			return
		}
		payments.HandleCheckoutCompleted(&sess)

	case "customer.subscription.updated":
		var sub stripe.Subscription
		if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
			log.Printf("[billing] Failed to parse customer.subscription.updated: %v", err)
			http.Error(w, "Parse error", http.StatusBadRequest)
			return
		}
		payments.HandleSubscriptionChange(&sub)

	case "customer.subscription.deleted":
		var sub stripe.Subscription
		if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
			log.Printf("[billing] Failed to parse customer.subscription.deleted: %v", err)
			http.Error(w, "Parse error", http.StatusBadRequest)
			return
		}
		payments.HandleSubscriptionDeleted(&sub)

	case "invoice.payment_failed":
		// Log the event ID rather than the raw payload. Stripe invoice objects
		// carry customer PII (customer_email, customer_name, address), which
		// should not end up in application logs. The full event can be
		// retrieved from Stripe by this ID.
		log.Printf("[billing] invoice.payment_failed: event %s", event.ID)

	default:
		log.Printf("[billing] Unhandled webhook event: %s", event.Type)
	}

	w.WriteHeader(http.StatusOK)
}
← Back