Write on the margins of the internet. Powered by the AT Protocol. margin.at
extension web atproto comments
at ui-refactor 618 lines 17 kB view raw
1package oauth 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/elliptic" 7 "crypto/rand" 8 "crypto/x509" 9 "encoding/json" 10 "encoding/pem" 11 "fmt" 12 "log" 13 "net/http" 14 "net/url" 15 "os" 16 "sync" 17 "time" 18 19 "margin.at/internal/db" 20 internal_sync "margin.at/internal/sync" 21 "margin.at/internal/xrpc" 22) 23 24type Handler struct { 25 db *db.DB 26 configuredBaseURL string 27 privateKey *ecdsa.PrivateKey 28 pending map[string]*PendingAuth 29 pendingMu sync.RWMutex 30 syncService *internal_sync.Service 31} 32 33func NewHandler(database *db.DB, syncService *internal_sync.Service) (*Handler, error) { 34 35 configuredBaseURL := os.Getenv("BASE_URL") 36 37 privateKey, err := loadOrGenerateKey() 38 if err != nil { 39 return nil, fmt.Errorf("failed to load/generate key: %w", err) 40 } 41 42 return &Handler{ 43 db: database, 44 configuredBaseURL: configuredBaseURL, 45 privateKey: privateKey, 46 pending: make(map[string]*PendingAuth), 47 syncService: syncService, 48 }, nil 49} 50 51func loadOrGenerateKey() (*ecdsa.PrivateKey, error) { 52 keyPath := os.Getenv("OAUTH_KEY_PATH") 53 if keyPath == "" { 54 keyPath = "./oauth_private_key.pem" 55 } 56 57 if data, err := os.ReadFile(keyPath); err == nil { 58 block, _ := pem.Decode(data) 59 if block != nil { 60 key, err := x509.ParseECPrivateKey(block.Bytes) 61 if err == nil { 62 return key, nil 63 } 64 } 65 } 66 67 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 68 if err != nil { 69 return nil, err 70 } 71 72 keyBytes, err := x509.MarshalECPrivateKey(key) 73 if err != nil { 74 return nil, err 75 } 76 77 block := &pem.Block{ 78 Type: "EC PRIVATE KEY", 79 Bytes: keyBytes, 80 } 81 82 if err := os.WriteFile(keyPath, pem.EncodeToMemory(block), 0600); err != nil { 83 log.Printf("Warning: could not save key to %s: %v\n", keyPath, err) 84 } 85 86 return key, nil 87} 88 89func (h *Handler) getDynamicClient(r *http.Request) *Client { 90 baseURL := h.configuredBaseURL 91 if baseURL == "" { 92 scheme := "http" 93 if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { 94 scheme = "https" 95 } 96 baseURL = fmt.Sprintf("%s://%s", scheme, r.Host) 97 } 98 99 if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { 100 baseURL = baseURL[:len(baseURL)-1] 101 } 102 103 clientID := baseURL + "/client-metadata.json" 104 redirectURI := baseURL + "/auth/callback" 105 106 return NewClient(clientID, redirectURI, h.privateKey) 107} 108 109func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) { 110 client := h.getDynamicClient(r) 111 112 handle := r.URL.Query().Get("handle") 113 if handle == "" { 114 http.Redirect(w, r, "/login", http.StatusFound) 115 return 116 } 117 118 ctx := r.Context() 119 120 did, err := client.ResolveHandle(ctx, handle) 121 if err != nil { 122 http.Error(w, fmt.Sprintf("Failed to resolve handle: %v", err), http.StatusBadRequest) 123 return 124 } 125 126 pds, err := client.ResolveDIDToPDS(ctx, did) 127 if err != nil { 128 http.Error(w, fmt.Sprintf("Failed to resolve PDS: %v", err), http.StatusBadRequest) 129 return 130 } 131 132 meta, err := client.GetAuthServerMetadata(ctx, pds) 133 if err != nil { 134 http.Error(w, fmt.Sprintf("Failed to get auth server metadata: %v", err), http.StatusBadRequest) 135 return 136 } 137 138 dpopKey, err := client.GenerateDPoPKey() 139 if err != nil { 140 http.Error(w, fmt.Sprintf("Failed to generate DPoP key: %v", err), http.StatusInternalServerError) 141 return 142 } 143 144 pkceVerifier, pkceChallenge := client.GeneratePKCE() 145 146 scope := "atproto offline_access blob:* include:at.margin.authFull" 147 148 parResp, state, dpopNonce, err := client.SendPAR(meta, handle, scope, dpopKey, pkceChallenge) 149 if err != nil { 150 http.Error(w, fmt.Sprintf("PAR request failed: %v", err), http.StatusInternalServerError) 151 return 152 } 153 154 pending := &PendingAuth{ 155 State: state, 156 DID: did, 157 PDS: pds, 158 AuthServer: meta.TokenEndpoint, 159 Issuer: meta.Issuer, 160 PKCEVerifier: pkceVerifier, 161 DPoPKey: dpopKey, 162 DPoPNonce: dpopNonce, 163 CreatedAt: time.Now(), 164 } 165 166 h.pendingMu.Lock() 167 h.pending[state] = pending 168 h.pendingMu.Unlock() 169 170 authURL, _ := url.Parse(meta.AuthorizationEndpoint) 171 q := authURL.Query() 172 q.Set("client_id", client.ClientID) 173 q.Set("request_uri", parResp.RequestURI) 174 authURL.RawQuery = q.Encode() 175 176 http.Redirect(w, r, authURL.String(), http.StatusFound) 177} 178 179func (h *Handler) HandleStart(w http.ResponseWriter, r *http.Request) { 180 if r.Method != "POST" { 181 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 182 return 183 } 184 185 var req struct { 186 Handle string `json:"handle"` 187 InviteCode string `json:"invite_code"` 188 } 189 if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 190 http.Error(w, "Invalid request body", http.StatusBadRequest) 191 return 192 } 193 194 if req.Handle == "" { 195 http.Error(w, "Handle is required", http.StatusBadRequest) 196 return 197 } 198 199 requiredCode := os.Getenv("INVITE_CODE") 200 if requiredCode != "" && req.InviteCode != requiredCode { 201 w.Header().Set("Content-Type", "application/json") 202 w.WriteHeader(http.StatusForbidden) 203 json.NewEncoder(w).Encode(map[string]string{ 204 "error": "Invite code required", 205 "code": "invite_required", 206 }) 207 return 208 } 209 210 client := h.getDynamicClient(r) 211 ctx := r.Context() 212 213 did, err := client.ResolveHandle(ctx, req.Handle) 214 if err != nil { 215 w.Header().Set("Content-Type", "application/json") 216 w.WriteHeader(http.StatusBadRequest) 217 json.NewEncoder(w).Encode(map[string]string{"error": "Could not find that account. Please check the handle."}) 218 return 219 } 220 221 pds, err := client.ResolveDIDToPDS(ctx, did) 222 if err != nil { 223 w.Header().Set("Content-Type", "application/json") 224 w.WriteHeader(http.StatusBadRequest) 225 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to resolve PDS"}) 226 return 227 } 228 229 meta, err := client.GetAuthServerMetadata(ctx, pds) 230 if err != nil { 231 w.Header().Set("Content-Type", "application/json") 232 w.WriteHeader(http.StatusInternalServerError) 233 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to get auth server"}) 234 return 235 } 236 237 dpopKey, err := client.GenerateDPoPKey() 238 if err != nil { 239 w.Header().Set("Content-Type", "application/json") 240 w.WriteHeader(http.StatusInternalServerError) 241 json.NewEncoder(w).Encode(map[string]string{"error": "Internal error"}) 242 return 243 } 244 245 pkceVerifier, pkceChallenge := client.GeneratePKCE() 246 scope := "atproto offline_access blob:* include:at.margin.authFull" 247 248 parResp, state, dpopNonce, err := client.SendPAR(meta, req.Handle, scope, dpopKey, pkceChallenge) 249 if err != nil { 250 log.Printf("PAR request failed: %v", err) 251 w.Header().Set("Content-Type", "application/json") 252 w.WriteHeader(http.StatusInternalServerError) 253 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to initiate authentication"}) 254 return 255 } 256 257 pending := &PendingAuth{ 258 State: state, 259 DID: did, 260 Handle: req.Handle, 261 PDS: pds, 262 AuthServer: meta.TokenEndpoint, 263 Issuer: meta.Issuer, 264 PKCEVerifier: pkceVerifier, 265 DPoPKey: dpopKey, 266 DPoPNonce: dpopNonce, 267 CreatedAt: time.Now(), 268 } 269 270 h.pendingMu.Lock() 271 h.pending[state] = pending 272 h.pendingMu.Unlock() 273 274 authURL, _ := url.Parse(meta.AuthorizationEndpoint) 275 q := authURL.Query() 276 q.Set("client_id", client.ClientID) 277 q.Set("request_uri", parResp.RequestURI) 278 authURL.RawQuery = q.Encode() 279 280 w.Header().Set("Content-Type", "application/json") 281 json.NewEncoder(w).Encode(map[string]string{ 282 "authorizationUrl": authURL.String(), 283 }) 284} 285 286func (h *Handler) HandleSignup(w http.ResponseWriter, r *http.Request) { 287 if r.Method != "POST" { 288 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 289 return 290 } 291 292 var req struct { 293 PdsURL string `json:"pds_url"` 294 } 295 if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 296 http.Error(w, "Invalid request body", http.StatusBadRequest) 297 return 298 } 299 300 if req.PdsURL == "" { 301 http.Error(w, "PDS URL is required", http.StatusBadRequest) 302 return 303 } 304 305 client := h.getDynamicClient(r) 306 ctx := r.Context() 307 308 meta, err := client.GetAuthServerMetadataForSignup(ctx, req.PdsURL) 309 if err != nil { 310 log.Printf("Failed to get auth metadata for signup from %s: %v", req.PdsURL, err) 311 w.Header().Set("Content-Type", "application/json") 312 w.WriteHeader(http.StatusBadRequest) 313 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to connect to PDS"}) 314 return 315 } 316 317 dpopKey, err := client.GenerateDPoPKey() 318 if err != nil { 319 w.Header().Set("Content-Type", "application/json") 320 w.WriteHeader(http.StatusInternalServerError) 321 json.NewEncoder(w).Encode(map[string]string{"error": "Internal error"}) 322 return 323 } 324 325 pkceVerifier, pkceChallenge := client.GeneratePKCE() 326 scope := "atproto offline_access blob:* include:at.margin.authFull" 327 328 parResp, state, dpopNonce, err := client.SendPAR(meta, "", scope, dpopKey, pkceChallenge) 329 if err != nil { 330 log.Printf("PAR request failed for signup: %v", err) 331 w.Header().Set("Content-Type", "application/json") 332 w.WriteHeader(http.StatusInternalServerError) 333 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to initiate signup"}) 334 return 335 } 336 337 pending := &PendingAuth{ 338 State: state, 339 DID: "", 340 Handle: "", 341 PDS: req.PdsURL, 342 AuthServer: meta.TokenEndpoint, 343 Issuer: meta.Issuer, 344 PKCEVerifier: pkceVerifier, 345 DPoPKey: dpopKey, 346 DPoPNonce: dpopNonce, 347 CreatedAt: time.Now(), 348 } 349 350 h.pendingMu.Lock() 351 h.pending[state] = pending 352 h.pendingMu.Unlock() 353 354 authURL, _ := url.Parse(meta.AuthorizationEndpoint) 355 q := authURL.Query() 356 q.Set("client_id", client.ClientID) 357 q.Set("request_uri", parResp.RequestURI) 358 authURL.RawQuery = q.Encode() 359 360 w.Header().Set("Content-Type", "application/json") 361 json.NewEncoder(w).Encode(map[string]string{ 362 "authorizationUrl": authURL.String(), 363 }) 364} 365 366func (h *Handler) HandleCallback(w http.ResponseWriter, r *http.Request) { 367 client := h.getDynamicClient(r) 368 369 state := r.URL.Query().Get("state") 370 code := r.URL.Query().Get("code") 371 iss := r.URL.Query().Get("iss") 372 373 if state == "" || code == "" { 374 http.Error(w, "Missing state or code parameter", http.StatusBadRequest) 375 return 376 } 377 378 h.pendingMu.Lock() 379 pending, ok := h.pending[state] 380 if ok { 381 delete(h.pending, state) 382 } 383 h.pendingMu.Unlock() 384 385 if !ok { 386 http.Error(w, "Invalid or expired state", http.StatusBadRequest) 387 return 388 } 389 390 if time.Since(pending.CreatedAt) > 10*time.Minute { 391 http.Error(w, "Authentication request expired", http.StatusBadRequest) 392 return 393 } 394 395 if iss != "" && iss != pending.Issuer { 396 http.Error(w, "Issuer mismatch", http.StatusBadRequest) 397 return 398 } 399 400 ctx := r.Context() 401 meta, err := client.GetAuthServerMetadataForSignup(ctx, pending.PDS) 402 if err != nil { 403 log.Printf("Failed to get auth metadata in callback for %s: %v", pending.PDS, err) 404 http.Error(w, fmt.Sprintf("Failed to get auth metadata: %v", err), http.StatusInternalServerError) 405 return 406 } 407 408 tokenResp, newNonce, err := client.ExchangeCode(meta, code, pending.PKCEVerifier, pending.DPoPKey, pending.DPoPNonce) 409 if err != nil { 410 http.Error(w, fmt.Sprintf("Token exchange failed: %v", err), http.StatusInternalServerError) 411 return 412 } 413 414 if pending.DID != "" && tokenResp.Sub != pending.DID { 415 log.Printf("Security: OAuth sub mismatch, expected %s, got %s", pending.DID, tokenResp.Sub) 416 http.Error(w, "Account identity mismatch, authorization returned different account", http.StatusBadRequest) 417 return 418 } 419 420 _ = newNonce 421 422 sessionID := generateSessionID() 423 expiresAt := time.Now().Add(7 * 24 * time.Hour) 424 425 dpopKeyBytes, err := x509.MarshalECPrivateKey(pending.DPoPKey) 426 if err != nil { 427 http.Error(w, "Failed to marshal DPoP key", http.StatusInternalServerError) 428 return 429 } 430 dpopKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: dpopKeyBytes}) 431 432 err = h.db.SaveSession( 433 sessionID, 434 tokenResp.Sub, 435 pending.Handle, 436 tokenResp.AccessToken, 437 tokenResp.RefreshToken, 438 string(dpopKeyPEM), 439 expiresAt, 440 ) 441 if err != nil { 442 http.Error(w, "Failed to save session", http.StatusInternalServerError) 443 return 444 } 445 446 http.SetCookie(w, &http.Cookie{ 447 Name: "margin_session", 448 Value: sessionID, 449 Path: "/", 450 HttpOnly: true, 451 Secure: true, 452 SameSite: http.SameSiteNoneMode, 453 MaxAge: 86400 * 7, 454 }) 455 456 go h.cleanupOrphanedReplies(tokenResp.Sub, tokenResp.AccessToken, string(dpopKeyPEM), pending.PDS) 457 go func() { 458 log.Printf("Starting background sync for %s...", tokenResp.Sub) 459 _, err := h.syncService.PerformSync(context.Background(), tokenResp.Sub, func(ctx context.Context, did string) (*xrpc.Client, error) { 460 return xrpc.NewClient(pending.PDS, tokenResp.AccessToken, pending.DPoPKey), nil 461 }) 462 463 if err != nil { 464 log.Printf("Background sync failed for %s: %v", tokenResp.Sub, err) 465 } else { 466 log.Printf("Background sync completed for %s", tokenResp.Sub) 467 } 468 }() 469 470 http.Redirect(w, r, "/?logged_in=true", http.StatusFound) 471} 472 473func (h *Handler) cleanupOrphanedReplies(did, accessToken, dpopKeyPEM, pds string) { 474 orphans, err := h.db.GetOrphanedRepliesByAuthor(did) 475 if err != nil || len(orphans) == 0 { 476 return 477 } 478 479 block, _ := pem.Decode([]byte(dpopKeyPEM)) 480 if block == nil { 481 return 482 } 483 dpopKey, err := x509.ParseECPrivateKey(block.Bytes) 484 if err != nil { 485 return 486 } 487 488 for _, reply := range orphans { 489 490 parts := url.PathEscape(reply.URI) 491 _ = parts 492 uriParts := splitURI(reply.URI) 493 if len(uriParts) < 2 { 494 continue 495 } 496 rkey := uriParts[len(uriParts)-1] 497 498 deleteFromPDS(pds, accessToken, dpopKey, "at.margin.reply", did, rkey) 499 500 h.db.DeleteReply(reply.URI) 501 } 502} 503 504func splitURI(uri string) []string { 505 506 return splitBySlash(uri) 507} 508 509func splitBySlash(s string) []string { 510 var result []string 511 current := "" 512 for _, c := range s { 513 if c == '/' { 514 if current != "" { 515 result = append(result, current) 516 } 517 current = "" 518 } else { 519 current += string(c) 520 } 521 } 522 if current != "" { 523 result = append(result, current) 524 } 525 return result 526} 527 528func deleteFromPDS(pds, accessToken string, dpopKey *ecdsa.PrivateKey, collection, did, rkey string) { 529 530 client := xrpc.NewClient(pds, accessToken, dpopKey) 531 err := client.DeleteRecord(context.Background(), collection, did, rkey) 532 if err != nil { 533 log.Printf("Failed to delete orphaned reply from PDS: %v", err) 534 } else { 535 log.Printf("Cleaned up orphaned reply %s/%s from PDS", collection, rkey) 536 } 537} 538 539func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) { 540 cookie, err := r.Cookie("margin_session") 541 if err == nil { 542 h.db.DeleteSession(cookie.Value) 543 } 544 545 http.SetCookie(w, &http.Cookie{ 546 Name: "margin_session", 547 Value: "", 548 Path: "/", 549 HttpOnly: true, 550 MaxAge: -1, 551 }) 552 553 w.Header().Set("Content-Type", "application/json") 554 json.NewEncoder(w).Encode(map[string]bool{"success": true}) 555} 556 557func (h *Handler) HandleSession(w http.ResponseWriter, r *http.Request) { 558 cookie, err := r.Cookie("margin_session") 559 if err != nil { 560 w.Header().Set("Content-Type", "application/json") 561 json.NewEncoder(w).Encode(map[string]interface{}{"authenticated": false}) 562 return 563 } 564 565 did, handle, _, _, _, err := h.db.GetSession(cookie.Value) 566 if err != nil { 567 w.Header().Set("Content-Type", "application/json") 568 json.NewEncoder(w).Encode(map[string]interface{}{"authenticated": false}) 569 return 570 } 571 572 w.Header().Set("Content-Type", "application/json") 573 json.NewEncoder(w).Encode(map[string]interface{}{ 574 "authenticated": true, 575 "did": did, 576 "handle": handle, 577 }) 578} 579 580func (h *Handler) HandleClientMetadata(w http.ResponseWriter, r *http.Request) { 581 client := h.getDynamicClient(r) 582 baseURL := client.ClientID[:len(client.ClientID)-len("/client-metadata.json")] 583 584 w.Header().Set("Content-Type", "application/json") 585 json.NewEncoder(w).Encode(map[string]interface{}{ 586 "client_id": client.ClientID, 587 "client_name": "Margin", 588 "client_uri": baseURL, 589 "logo_uri": baseURL + "/logo.svg", 590 "tos_uri": baseURL + "/terms", 591 "policy_uri": baseURL + "/privacy", 592 "redirect_uris": []string{client.RedirectURI}, 593 "grant_types": []string{"authorization_code", "refresh_token"}, 594 "response_types": []string{"code"}, 595 "scope": "atproto offline_access blob:* include:at.margin.authFull", 596 "token_endpoint_auth_method": "private_key_jwt", 597 "token_endpoint_auth_signing_alg": "ES256", 598 "dpop_bound_access_tokens": true, 599 "jwks_uri": baseURL + "/jwks.json", 600 "application_type": "web", 601 }) 602} 603 604func (h *Handler) HandleJWKS(w http.ResponseWriter, r *http.Request) { 605 client := h.getDynamicClient(r) 606 w.Header().Set("Content-Type", "application/json") 607 json.NewEncoder(w).Encode(client.GetPublicJWKS()) 608} 609 610func (h *Handler) GetPrivateKey() *ecdsa.PrivateKey { 611 return h.privateKey 612} 613 614func generateSessionID() string { 615 b := make([]byte, 32) 616 rand.Read(b) 617 return fmt.Sprintf("%x", b) 618}