Write on the margins of the internet. Powered by the AT Protocol.
margin.at
extension
web
atproto
comments
1package oauth
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "crypto/elliptic"
7 "crypto/rand"
8 "crypto/x509"
9 "encoding/json"
10 "encoding/pem"
11 "fmt"
12 "log"
13 "net/http"
14 "net/url"
15 "os"
16 "sync"
17 "time"
18
19 "margin.at/internal/db"
20 internal_sync "margin.at/internal/sync"
21 "margin.at/internal/xrpc"
22)
23
// Handler serves the OAuth endpoints of the margin.at atproto client: login
// start (HandleLogin, HandleStart, HandleSignup), HandleCallback,
// HandleLogout, HandleSession, and the client metadata / JWKS documents.
type Handler struct {
	// db stores sessions and is queried for orphaned replies after login.
	db *db.DB

	// configuredBaseURL is the BASE_URL environment value read by NewHandler.
	// When empty, getDynamicClient derives the base URL from each request.
	configuredBaseURL string

	// privateKey is the client key from loadOrGenerateKey; it is passed to
	// every Client built by getDynamicClient and exposed via GetPrivateKey.
	privateKey *ecdsa.PrivateKey

	// pending holds in-flight authorization attempts keyed by OAuth state.
	// HandleLogin, HandleStart and HandleSignup add entries; HandleCallback
	// removes the entry it consumes.
	// NOTE(review): in the original version of this file nothing else deletes
	// entries, so attempts that never reach the callback stay in memory —
	// confirm whether a sweep exists elsewhere.
	pending map[string]*PendingAuth

	// pendingMu guards pending.
	pendingMu sync.RWMutex

	// syncService runs the background repository sync started after login.
	syncService *internal_sync.Service
}
32
33func NewHandler(database *db.DB, syncService *internal_sync.Service) (*Handler, error) {
34
35 configuredBaseURL := os.Getenv("BASE_URL")
36
37 privateKey, err := loadOrGenerateKey()
38 if err != nil {
39 return nil, fmt.Errorf("failed to load/generate key: %w", err)
40 }
41
42 return &Handler{
43 db: database,
44 configuredBaseURL: configuredBaseURL,
45 privateKey: privateKey,
46 pending: make(map[string]*PendingAuth),
47 syncService: syncService,
48 }, nil
49}
50
// loadOrGenerateKey returns the ECDSA private key this OAuth client signs with.
//
// The key lives at $OAUTH_KEY_PATH, or ./oauth_private_key.pem when that
// variable is unset. Behavior depends on the file's state:
//   - the file exists and holds a PEM-encoded EC key, in SEC 1
//     ("EC PRIVATE KEY") or PKCS #8 ("PRIVATE KEY") form: that key is returned.
//   - the file does not exist: a new P-256 key is generated and written there
//     as SEC 1 PEM with mode 0600. A failed write is only logged, and the
//     in-memory key is still returned.
//   - the file exists but cannot be read or parsed: an error is returned.
//
// The last case deliberately does not fall back to generating a key. That
// fallback would overwrite the operator's file and silently rotate the client
// key.
func loadOrGenerateKey() (*ecdsa.PrivateKey, error) {
	keyPath := os.Getenv("OAUTH_KEY_PATH")
	if keyPath == "" {
		keyPath = "./oauth_private_key.pem"
	}

	data, err := os.ReadFile(keyPath)
	if err == nil {
		key, err := decodeECPrivateKeyPEM(data)
		if err != nil {
			return nil, fmt.Errorf("existing OAuth key %s is unusable (refusing to overwrite it): %w", keyPath, err)
		}
		return key, nil
	}
	if !os.IsNotExist(err) {
		// The file may exist but be unreadable (permissions, a directory, ...);
		// generating a replacement here would mask the misconfiguration.
		return nil, fmt.Errorf("reading OAuth key %s: %w", keyPath, err)
	}

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, err
	}

	keyBytes, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		return nil, err
	}

	block := &pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: keyBytes,
	}

	if err := os.WriteFile(keyPath, pem.EncodeToMemory(block), 0600); err != nil {
		log.Printf("Warning: could not save key to %s: %v\n", keyPath, err)
	}

	return key, nil
}

// decodeECPrivateKeyPEM parses the first PEM block in data as an ECDSA private
// key. It accepts SEC 1 ("EC PRIVATE KEY") and PKCS #8 ("PRIVATE KEY")
// encodings.
func decodeECPrivateKeyPEM(data []byte) (*ecdsa.PrivateKey, error) {
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("no PEM block found")
	}
	if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
		return key, nil
	}
	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("not an EC private key: %w", err)
	}
	key, ok := parsed.(*ecdsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("PKCS #8 key is %T, not ECDSA", parsed)
	}
	return key, nil
}
88
89func (h *Handler) getDynamicClient(r *http.Request) *Client {
90 baseURL := h.configuredBaseURL
91 if baseURL == "" {
92 scheme := "http"
93 if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
94 scheme = "https"
95 }
96 baseURL = fmt.Sprintf("%s://%s", scheme, r.Host)
97 }
98
99 if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
100 baseURL = baseURL[:len(baseURL)-1]
101 }
102
103 clientID := baseURL + "/client-metadata.json"
104 redirectURI := baseURL + "/auth/callback"
105
106 return NewClient(clientID, redirectURI, h.privateKey)
107}
108
109func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
110 client := h.getDynamicClient(r)
111
112 handle := r.URL.Query().Get("handle")
113 if handle == "" {
114 http.Redirect(w, r, "/login", http.StatusFound)
115 return
116 }
117
118 ctx := r.Context()
119
120 did, err := client.ResolveHandle(ctx, handle)
121 if err != nil {
122 http.Error(w, fmt.Sprintf("Failed to resolve handle: %v", err), http.StatusBadRequest)
123 return
124 }
125
126 pds, err := client.ResolveDIDToPDS(ctx, did)
127 if err != nil {
128 http.Error(w, fmt.Sprintf("Failed to resolve PDS: %v", err), http.StatusBadRequest)
129 return
130 }
131
132 meta, err := client.GetAuthServerMetadata(ctx, pds)
133 if err != nil {
134 http.Error(w, fmt.Sprintf("Failed to get auth server metadata: %v", err), http.StatusBadRequest)
135 return
136 }
137
138 dpopKey, err := client.GenerateDPoPKey()
139 if err != nil {
140 http.Error(w, fmt.Sprintf("Failed to generate DPoP key: %v", err), http.StatusInternalServerError)
141 return
142 }
143
144 pkceVerifier, pkceChallenge := client.GeneratePKCE()
145
146 scope := "atproto offline_access blob:* include:at.margin.authFull"
147
148 parResp, state, dpopNonce, err := client.SendPAR(meta, handle, scope, dpopKey, pkceChallenge)
149 if err != nil {
150 http.Error(w, fmt.Sprintf("PAR request failed: %v", err), http.StatusInternalServerError)
151 return
152 }
153
154 pending := &PendingAuth{
155 State: state,
156 DID: did,
157 PDS: pds,
158 AuthServer: meta.TokenEndpoint,
159 Issuer: meta.Issuer,
160 PKCEVerifier: pkceVerifier,
161 DPoPKey: dpopKey,
162 DPoPNonce: dpopNonce,
163 CreatedAt: time.Now(),
164 }
165
166 h.pendingMu.Lock()
167 h.pending[state] = pending
168 h.pendingMu.Unlock()
169
170 authURL, _ := url.Parse(meta.AuthorizationEndpoint)
171 q := authURL.Query()
172 q.Set("client_id", client.ClientID)
173 q.Set("request_uri", parResp.RequestURI)
174 authURL.RawQuery = q.Encode()
175
176 http.Redirect(w, r, authURL.String(), http.StatusFound)
177}
178
179func (h *Handler) HandleStart(w http.ResponseWriter, r *http.Request) {
180 if r.Method != "POST" {
181 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
182 return
183 }
184
185 var req struct {
186 Handle string `json:"handle"`
187 InviteCode string `json:"invite_code"`
188 }
189 if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
190 http.Error(w, "Invalid request body", http.StatusBadRequest)
191 return
192 }
193
194 if req.Handle == "" {
195 http.Error(w, "Handle is required", http.StatusBadRequest)
196 return
197 }
198
199 requiredCode := os.Getenv("INVITE_CODE")
200 if requiredCode != "" && req.InviteCode != requiredCode {
201 w.Header().Set("Content-Type", "application/json")
202 w.WriteHeader(http.StatusForbidden)
203 json.NewEncoder(w).Encode(map[string]string{
204 "error": "Invite code required",
205 "code": "invite_required",
206 })
207 return
208 }
209
210 client := h.getDynamicClient(r)
211 ctx := r.Context()
212
213 did, err := client.ResolveHandle(ctx, req.Handle)
214 if err != nil {
215 w.Header().Set("Content-Type", "application/json")
216 w.WriteHeader(http.StatusBadRequest)
217 json.NewEncoder(w).Encode(map[string]string{"error": "Could not find that account. Please check the handle."})
218 return
219 }
220
221 pds, err := client.ResolveDIDToPDS(ctx, did)
222 if err != nil {
223 w.Header().Set("Content-Type", "application/json")
224 w.WriteHeader(http.StatusBadRequest)
225 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to resolve PDS"})
226 return
227 }
228
229 meta, err := client.GetAuthServerMetadata(ctx, pds)
230 if err != nil {
231 w.Header().Set("Content-Type", "application/json")
232 w.WriteHeader(http.StatusInternalServerError)
233 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to get auth server"})
234 return
235 }
236
237 dpopKey, err := client.GenerateDPoPKey()
238 if err != nil {
239 w.Header().Set("Content-Type", "application/json")
240 w.WriteHeader(http.StatusInternalServerError)
241 json.NewEncoder(w).Encode(map[string]string{"error": "Internal error"})
242 return
243 }
244
245 pkceVerifier, pkceChallenge := client.GeneratePKCE()
246 scope := "atproto offline_access blob:* include:at.margin.authFull"
247
248 parResp, state, dpopNonce, err := client.SendPAR(meta, req.Handle, scope, dpopKey, pkceChallenge)
249 if err != nil {
250 log.Printf("PAR request failed: %v", err)
251 w.Header().Set("Content-Type", "application/json")
252 w.WriteHeader(http.StatusInternalServerError)
253 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to initiate authentication"})
254 return
255 }
256
257 pending := &PendingAuth{
258 State: state,
259 DID: did,
260 Handle: req.Handle,
261 PDS: pds,
262 AuthServer: meta.TokenEndpoint,
263 Issuer: meta.Issuer,
264 PKCEVerifier: pkceVerifier,
265 DPoPKey: dpopKey,
266 DPoPNonce: dpopNonce,
267 CreatedAt: time.Now(),
268 }
269
270 h.pendingMu.Lock()
271 h.pending[state] = pending
272 h.pendingMu.Unlock()
273
274 authURL, _ := url.Parse(meta.AuthorizationEndpoint)
275 q := authURL.Query()
276 q.Set("client_id", client.ClientID)
277 q.Set("request_uri", parResp.RequestURI)
278 authURL.RawQuery = q.Encode()
279
280 w.Header().Set("Content-Type", "application/json")
281 json.NewEncoder(w).Encode(map[string]string{
282 "authorizationUrl": authURL.String(),
283 })
284}
285
286func (h *Handler) HandleSignup(w http.ResponseWriter, r *http.Request) {
287 if r.Method != "POST" {
288 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
289 return
290 }
291
292 var req struct {
293 PdsURL string `json:"pds_url"`
294 }
295 if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
296 http.Error(w, "Invalid request body", http.StatusBadRequest)
297 return
298 }
299
300 if req.PdsURL == "" {
301 http.Error(w, "PDS URL is required", http.StatusBadRequest)
302 return
303 }
304
305 client := h.getDynamicClient(r)
306 ctx := r.Context()
307
308 meta, err := client.GetAuthServerMetadataForSignup(ctx, req.PdsURL)
309 if err != nil {
310 log.Printf("Failed to get auth metadata for signup from %s: %v", req.PdsURL, err)
311 w.Header().Set("Content-Type", "application/json")
312 w.WriteHeader(http.StatusBadRequest)
313 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to connect to PDS"})
314 return
315 }
316
317 dpopKey, err := client.GenerateDPoPKey()
318 if err != nil {
319 w.Header().Set("Content-Type", "application/json")
320 w.WriteHeader(http.StatusInternalServerError)
321 json.NewEncoder(w).Encode(map[string]string{"error": "Internal error"})
322 return
323 }
324
325 pkceVerifier, pkceChallenge := client.GeneratePKCE()
326 scope := "atproto offline_access blob:* include:at.margin.authFull"
327
328 parResp, state, dpopNonce, err := client.SendPAR(meta, "", scope, dpopKey, pkceChallenge)
329 if err != nil {
330 log.Printf("PAR request failed for signup: %v", err)
331 w.Header().Set("Content-Type", "application/json")
332 w.WriteHeader(http.StatusInternalServerError)
333 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to initiate signup"})
334 return
335 }
336
337 pending := &PendingAuth{
338 State: state,
339 DID: "",
340 Handle: "",
341 PDS: req.PdsURL,
342 AuthServer: meta.TokenEndpoint,
343 Issuer: meta.Issuer,
344 PKCEVerifier: pkceVerifier,
345 DPoPKey: dpopKey,
346 DPoPNonce: dpopNonce,
347 CreatedAt: time.Now(),
348 }
349
350 h.pendingMu.Lock()
351 h.pending[state] = pending
352 h.pendingMu.Unlock()
353
354 authURL, _ := url.Parse(meta.AuthorizationEndpoint)
355 q := authURL.Query()
356 q.Set("client_id", client.ClientID)
357 q.Set("request_uri", parResp.RequestURI)
358 authURL.RawQuery = q.Encode()
359
360 w.Header().Set("Content-Type", "application/json")
361 json.NewEncoder(w).Encode(map[string]string{
362 "authorizationUrl": authURL.String(),
363 })
364}
365
366func (h *Handler) HandleCallback(w http.ResponseWriter, r *http.Request) {
367 client := h.getDynamicClient(r)
368
369 state := r.URL.Query().Get("state")
370 code := r.URL.Query().Get("code")
371 iss := r.URL.Query().Get("iss")
372
373 if state == "" || code == "" {
374 http.Error(w, "Missing state or code parameter", http.StatusBadRequest)
375 return
376 }
377
378 h.pendingMu.Lock()
379 pending, ok := h.pending[state]
380 if ok {
381 delete(h.pending, state)
382 }
383 h.pendingMu.Unlock()
384
385 if !ok {
386 http.Error(w, "Invalid or expired state", http.StatusBadRequest)
387 return
388 }
389
390 if time.Since(pending.CreatedAt) > 10*time.Minute {
391 http.Error(w, "Authentication request expired", http.StatusBadRequest)
392 return
393 }
394
395 if iss != "" && iss != pending.Issuer {
396 http.Error(w, "Issuer mismatch", http.StatusBadRequest)
397 return
398 }
399
400 ctx := r.Context()
401 meta, err := client.GetAuthServerMetadataForSignup(ctx, pending.PDS)
402 if err != nil {
403 log.Printf("Failed to get auth metadata in callback for %s: %v", pending.PDS, err)
404 http.Error(w, fmt.Sprintf("Failed to get auth metadata: %v", err), http.StatusInternalServerError)
405 return
406 }
407
408 tokenResp, newNonce, err := client.ExchangeCode(meta, code, pending.PKCEVerifier, pending.DPoPKey, pending.DPoPNonce)
409 if err != nil {
410 http.Error(w, fmt.Sprintf("Token exchange failed: %v", err), http.StatusInternalServerError)
411 return
412 }
413
414 if pending.DID != "" && tokenResp.Sub != pending.DID {
415 log.Printf("Security: OAuth sub mismatch, expected %s, got %s", pending.DID, tokenResp.Sub)
416 http.Error(w, "Account identity mismatch, authorization returned different account", http.StatusBadRequest)
417 return
418 }
419
420 _ = newNonce
421
422 sessionID := generateSessionID()
423 expiresAt := time.Now().Add(7 * 24 * time.Hour)
424
425 dpopKeyBytes, err := x509.MarshalECPrivateKey(pending.DPoPKey)
426 if err != nil {
427 http.Error(w, "Failed to marshal DPoP key", http.StatusInternalServerError)
428 return
429 }
430 dpopKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: dpopKeyBytes})
431
432 err = h.db.SaveSession(
433 sessionID,
434 tokenResp.Sub,
435 pending.Handle,
436 tokenResp.AccessToken,
437 tokenResp.RefreshToken,
438 string(dpopKeyPEM),
439 expiresAt,
440 )
441 if err != nil {
442 http.Error(w, "Failed to save session", http.StatusInternalServerError)
443 return
444 }
445
446 http.SetCookie(w, &http.Cookie{
447 Name: "margin_session",
448 Value: sessionID,
449 Path: "/",
450 HttpOnly: true,
451 Secure: true,
452 SameSite: http.SameSiteNoneMode,
453 MaxAge: 86400 * 7,
454 })
455
456 go h.cleanupOrphanedReplies(tokenResp.Sub, tokenResp.AccessToken, string(dpopKeyPEM), pending.PDS)
457 go func() {
458 log.Printf("Starting background sync for %s...", tokenResp.Sub)
459 _, err := h.syncService.PerformSync(context.Background(), tokenResp.Sub, func(ctx context.Context, did string) (*xrpc.Client, error) {
460 return xrpc.NewClient(pending.PDS, tokenResp.AccessToken, pending.DPoPKey), nil
461 })
462
463 if err != nil {
464 log.Printf("Background sync failed for %s: %v", tokenResp.Sub, err)
465 } else {
466 log.Printf("Background sync completed for %s", tokenResp.Sub)
467 }
468 }()
469
470 http.Redirect(w, r, "/?logged_in=true", http.StatusFound)
471}
472
473func (h *Handler) cleanupOrphanedReplies(did, accessToken, dpopKeyPEM, pds string) {
474 orphans, err := h.db.GetOrphanedRepliesByAuthor(did)
475 if err != nil || len(orphans) == 0 {
476 return
477 }
478
479 block, _ := pem.Decode([]byte(dpopKeyPEM))
480 if block == nil {
481 return
482 }
483 dpopKey, err := x509.ParseECPrivateKey(block.Bytes)
484 if err != nil {
485 return
486 }
487
488 for _, reply := range orphans {
489
490 parts := url.PathEscape(reply.URI)
491 _ = parts
492 uriParts := splitURI(reply.URI)
493 if len(uriParts) < 2 {
494 continue
495 }
496 rkey := uriParts[len(uriParts)-1]
497
498 deleteFromPDS(pds, accessToken, dpopKey, "at.margin.reply", did, rkey)
499
500 h.db.DeleteReply(reply.URI)
501 }
502}
503
504func splitURI(uri string) []string {
505
506 return splitBySlash(uri)
507}
508
// splitBySlash splits s on '/' and returns the non-empty segments in order.
//
// Leading, trailing and repeated slashes yield no empty strings. An input with
// no non-empty segment (for example "" or "///") yields a nil slice.
//
// Segments are substrings of s, found by scanning bytes. This is safe because
// '/' never occurs inside a multi-byte UTF-8 sequence. It is also linear, and
// it keeps any invalid UTF-8 bytes intact instead of rewriting them to U+FFFD.
func splitBySlash(s string) []string {
	var result []string
	start := 0
	for i := 0; i < len(s); i++ {
		if s[i] != '/' {
			continue
		}
		if i > start {
			result = append(result, s[start:i])
		}
		start = i + 1
	}
	if start < len(s) {
		result = append(result, s[start:])
	}
	return result
}
527
528func deleteFromPDS(pds, accessToken string, dpopKey *ecdsa.PrivateKey, collection, did, rkey string) {
529
530 client := xrpc.NewClient(pds, accessToken, dpopKey)
531 err := client.DeleteRecord(context.Background(), collection, did, rkey)
532 if err != nil {
533 log.Printf("Failed to delete orphaned reply from PDS: %v", err)
534 } else {
535 log.Printf("Cleaned up orphaned reply %s/%s from PDS", collection, rkey)
536 }
537}
538
539func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
540 cookie, err := r.Cookie("margin_session")
541 if err == nil {
542 h.db.DeleteSession(cookie.Value)
543 }
544
545 http.SetCookie(w, &http.Cookie{
546 Name: "margin_session",
547 Value: "",
548 Path: "/",
549 HttpOnly: true,
550 MaxAge: -1,
551 })
552
553 w.Header().Set("Content-Type", "application/json")
554 json.NewEncoder(w).Encode(map[string]bool{"success": true})
555}
556
557func (h *Handler) HandleSession(w http.ResponseWriter, r *http.Request) {
558 cookie, err := r.Cookie("margin_session")
559 if err != nil {
560 w.Header().Set("Content-Type", "application/json")
561 json.NewEncoder(w).Encode(map[string]interface{}{"authenticated": false})
562 return
563 }
564
565 did, handle, _, _, _, err := h.db.GetSession(cookie.Value)
566 if err != nil {
567 w.Header().Set("Content-Type", "application/json")
568 json.NewEncoder(w).Encode(map[string]interface{}{"authenticated": false})
569 return
570 }
571
572 w.Header().Set("Content-Type", "application/json")
573 json.NewEncoder(w).Encode(map[string]interface{}{
574 "authenticated": true,
575 "did": did,
576 "handle": handle,
577 })
578}
579
580func (h *Handler) HandleClientMetadata(w http.ResponseWriter, r *http.Request) {
581 client := h.getDynamicClient(r)
582 baseURL := client.ClientID[:len(client.ClientID)-len("/client-metadata.json")]
583
584 w.Header().Set("Content-Type", "application/json")
585 json.NewEncoder(w).Encode(map[string]interface{}{
586 "client_id": client.ClientID,
587 "client_name": "Margin",
588 "client_uri": baseURL,
589 "logo_uri": baseURL + "/logo.svg",
590 "tos_uri": baseURL + "/terms",
591 "policy_uri": baseURL + "/privacy",
592 "redirect_uris": []string{client.RedirectURI},
593 "grant_types": []string{"authorization_code", "refresh_token"},
594 "response_types": []string{"code"},
595 "scope": "atproto offline_access blob:* include:at.margin.authFull",
596 "token_endpoint_auth_method": "private_key_jwt",
597 "token_endpoint_auth_signing_alg": "ES256",
598 "dpop_bound_access_tokens": true,
599 "jwks_uri": baseURL + "/jwks.json",
600 "application_type": "web",
601 })
602}
603
604func (h *Handler) HandleJWKS(w http.ResponseWriter, r *http.Request) {
605 client := h.getDynamicClient(r)
606 w.Header().Set("Content-Type", "application/json")
607 json.NewEncoder(w).Encode(client.GetPublicJWKS())
608}
609
610func (h *Handler) GetPrivateKey() *ecdsa.PrivateKey {
611 return h.privateKey
612}
613
// generateSessionID returns a fresh session identifier: 32 bytes from
// crypto/rand, encoded as 64 lowercase hexadecimal characters.
//
// It panics if the system random source fails. The previous code ignored that
// error, which could leave the buffer partly unfilled (zero bytes) and produce
// a guessable session ID. There is no safe value to return instead.
func generateSessionID() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("oauth: reading crypto/rand for session ID: %v", err))
	}
	return fmt.Sprintf("%x", b)
}