Yōten: A social tracker for your language learning journey built on the atproto.

feat: use redis to store oauth sessions and switch to new indigo oauth library

Signed-off-by: brookjeynes <me@brookjeynes.dev>

brookjeynes.dev bd2ffaa0 ded4764a

verified
+762 -969
-55
internal/atproto/xrpc.go
··· 1 - package atproto 2 - 3 - import ( 4 - "context" 5 - 6 - "github.com/bluesky-social/indigo/api/atproto" 7 - "github.com/bluesky-social/indigo/xrpc" 8 - oauth "tangled.sh/icyphox.sh/atproto-oauth" 9 - ) 10 - 11 - type Client struct { 12 - *oauth.XrpcClient 13 - authArgs *oauth.XrpcAuthedRequestArgs 14 - } 15 - 16 - func NewClient(client *oauth.XrpcClient, authArgs *oauth.XrpcAuthedRequestArgs) *Client { 17 - return &Client{ 18 - XrpcClient: client, 19 - authArgs: authArgs, 20 - } 21 - } 22 - 23 - func (c *Client) RepoPutRecord(ctx context.Context, input *atproto.RepoPutRecord_Input) (*atproto.RepoPutRecord_Output, error) { 24 - var out atproto.RepoPutRecord_Output 25 - if err := c.Do(ctx, c.authArgs, xrpc.Procedure, "application/json", "com.atproto.repo.putRecord", nil, input, &out); err != nil { 26 - return nil, err 27 - } 28 - 29 - return &out, nil 30 - } 31 - 32 - func (c *Client) RepoGetRecord(ctx context.Context, cid string, collection string, repo string, rkey string) (*atproto.RepoGetRecord_Output, error) { 33 - var out atproto.RepoGetRecord_Output 34 - 35 - params := map[string]any{ 36 - "cid": cid, 37 - "collection": collection, 38 - "repo": repo, 39 - "rkey": rkey, 40 - } 41 - if err := c.Do(ctx, c.authArgs, xrpc.Query, "", "com.atproto.repo.getRecord", params, nil, &out); err != nil { 42 - return nil, err 43 - } 44 - 45 - return &out, nil 46 - } 47 - 48 - func (c *Client) RepoDeleteRecord(ctx context.Context, input *atproto.RepoDeleteRecord_Input) (*atproto.RepoDeleteRecord_Output, error) { 49 - var out atproto.RepoDeleteRecord_Output 50 - if err := c.Do(ctx, c.authArgs, xrpc.Procedure, "application/json", "com.atproto.repo.deleteRecord", nil, input, &out); err != nil { 51 - return nil, err 52 - } 53 - 54 - return &out, nil 55 - }
+14
internal/cache/cache.go
··· 1 + package cache 2 + 3 + import "github.com/redis/go-redis/v9" 4 + 5 + type Cache struct { 6 + *redis.Client 7 + } 8 + 9 + func New(addr string) *Cache { 10 + rdb := redis.NewClient(&redis.Options{ 11 + Addr: addr, 12 + }) 13 + return &Cache{rdb} 14 + }
+172
internal/cache/session/store.go
··· 1 + package session 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "fmt" 7 + "time" 8 + 9 + "yoten.app/internal/cache" 10 + ) 11 + 12 + type OAuthSession struct { 13 + Handle string 14 + Did string 15 + PdsUrl string 16 + AccessJwt string 17 + RefreshJwt string 18 + AuthServerIss string 19 + DpopPdsNonce string 20 + DpopAuthserverNonce string 21 + DpopPrivateJwk string 22 + Expiry string 23 + } 24 + 25 + type OAuthRequest struct { 26 + AuthserverIss string 27 + Handle string 28 + State string 29 + Did string 30 + PdsUrl string 31 + PkceVerifier string 32 + DpopAuthserverNonce string 33 + DpopPrivateJwk string 34 + ReturnUrl string 35 + } 36 + 37 + type SessionStore struct { 38 + cache *cache.Cache 39 + } 40 + 41 + const ( 42 + stateKey = "oauthstate:%s" 43 + requestKey = "oauthrequest:%s" 44 + sessionKey = "oauthsession:%s" 45 + ) 46 + 47 + func New(cache *cache.Cache) *SessionStore { 48 + return &SessionStore{cache: cache} 49 + } 50 + 51 + func (s *SessionStore) SaveSession(ctx context.Context, session OAuthSession) error { 52 + key := fmt.Sprintf(sessionKey, session.Did) 53 + data, err := json.Marshal(session) 54 + if err != nil { 55 + return err 56 + } 57 + 58 + // set with ttl (7 days) 59 + ttl := 7 * 24 * time.Hour 60 + 61 + return s.cache.Set(ctx, key, data, ttl).Err() 62 + } 63 + 64 + // SaveRequest stores the OAuth request to be later fetched in the callback. Since 65 + // the fetching happens by comparing the state we get in the callback params, we 66 + // store an additional state->did mapping which then lets us fetch the whole OAuth request. 67 + func (s *SessionStore) SaveRequest(ctx context.Context, request OAuthRequest) error { 68 + key := fmt.Sprintf(requestKey, request.Did) 69 + data, err := json.Marshal(request) 70 + if err != nil { 71 + return err 72 + } 73 + 74 + // oauth flow must complete within 30 minutes 75 + err = s.cache.Set(ctx, key, data, 30*time.Minute).Err() 76 + if err != nil { 77 + return fmt.Errorf("error saving request: %w", err) 78 + } 79 + 80 + stateKey := fmt.Sprintf(stateKey, request.State) 81 + err = s.cache.Set(ctx, stateKey, request.Did, 30*time.Minute).Err() 82 + if err != nil { 83 + return fmt.Errorf("error saving state->did mapping: %w", err) 84 + } 85 + 86 + return nil 87 + } 88 + 89 + func (s *SessionStore) GetSession(ctx context.Context, did string) (*OAuthSession, error) { 90 + key := fmt.Sprintf(sessionKey, did) 91 + val, err := s.cache.Get(ctx, key).Result() 92 + if err != nil { 93 + return nil, err 94 + } 95 + 96 + var session OAuthSession 97 + err = json.Unmarshal([]byte(val), &session) 98 + if err != nil { 99 + return nil, err 100 + } 101 + return &session, nil 102 + } 103 + 104 + func (s *SessionStore) GetRequestByState(ctx context.Context, state string) (*OAuthRequest, error) { 105 + didKey, err := s.getRequestKeyFromState(ctx, state) 106 + if err != nil { 107 + return nil, err 108 + } 109 + 110 + val, err := s.cache.Get(ctx, didKey).Result() 111 + if err != nil { 112 + return nil, err 113 + } 114 + 115 + var request OAuthRequest 116 + err = json.Unmarshal([]byte(val), &request) 117 + if err != nil { 118 + return nil, err 119 + } 120 + 121 + return &request, nil 122 + } 123 + 124 + func (s *SessionStore) DeleteSession(ctx context.Context, did string) error { 125 + key := fmt.Sprintf(sessionKey, did) 126 + return s.cache.Del(ctx, key).Err() 127 + } 128 + 129 + func (s *SessionStore) DeleteRequestByState(ctx context.Context, state string) error { 130 + didKey, err := s.getRequestKeyFromState(ctx, state) 131 + if err != nil { 132 + return err 133 + } 134 + 135 + err = s.cache.Del(ctx, fmt.Sprintf(stateKey, state)).Err() 136 + if err != nil { 137 + return err 138 + } 139 + 140 + return s.cache.Del(ctx, didKey).Err() 141 + } 142 + 143 + func (s *SessionStore) RefreshSession(ctx context.Context, did, access, refresh, expiry string) error { 144 + session, err := s.GetSession(ctx, did) 145 + if err != nil { 146 + return err 147 + } 148 + session.AccessJwt = access 149 + session.RefreshJwt = refresh 150 + session.Expiry = expiry 151 + return s.SaveSession(ctx, *session) 152 + } 153 + 154 + func (s *SessionStore) UpdateNonce(ctx context.Context, did, nonce string) error { 155 + session, err := s.GetSession(ctx, did) 156 + if err != nil { 157 + return err 158 + } 159 + session.DpopAuthserverNonce = nonce 160 + return s.SaveSession(ctx, *session) 161 + } 162 + 163 + func (s *SessionStore) getRequestKeyFromState(ctx context.Context, state string) (string, error) { 164 + key := fmt.Sprintf(stateKey, state) 165 + did, err := s.cache.Get(ctx, key).Result() 166 + if err != nil { 167 + return "", err 168 + } 169 + 170 + didKey := fmt.Sprintf(requestKey, did) 171 + return didKey, nil 172 + }
+16 -34
internal/db/db.go
··· 4 4 "context" 5 5 "database/sql" 6 6 "fmt" 7 + "strings" 7 8 8 9 _ "github.com/mattn/go-sqlite3" 9 10 ) ··· 24 25 } 25 26 26 27 func Make(dbPath string) (*DB, error) { 27 - db, err := sql.Open("sqlite3", dbPath) 28 + opts := []string{ 29 + "_foreign_keys=1", 30 + "_journal_mode=WAL", 31 + "_synchronous=NORMAL", 32 + "_auto_vacuum=incremental", 33 + } 34 + 35 + db, err := sql.Open("sqlite3", dbPath+"?"+strings.Join(opts, "&")) 28 36 if err != nil { 29 37 return nil, fmt.Errorf("failed to open db: %w", err) 30 38 } 31 - _, err = db.Exec(` 32 - pragma journal_mode = WAL; 33 - pragma synchronous = normal; 34 - pragma foreign_keys = on; 35 - pragma temp_store = memory; 36 - pragma mmap_size = 30000000000; 37 - pragma page_size = 32768; 38 - pragma auto_vacuum = incremental; 39 - pragma busy_timeout = 5000; 40 39 41 - create table if not exists oauth_requests ( 42 - id integer primary key autoincrement, 43 - auth_server_iss text not null, 44 - state text not null, 45 - did text not null, 46 - handle text not null, 47 - pds_url text not null, 48 - pkce_verifier text not null, 49 - dpop_auth_server_nonce text not null, 50 - dpop_private_jwk text not null 51 - ); 40 + ctx := context.Background() 52 41 53 - create table if not exists oauth_sessions ( 54 - id integer primary key autoincrement, 55 - did text not null, 56 - handle text not null, 57 - pds_url text not null, 58 - auth_server_iss text not null, 59 - access_jwt text not null, 60 - refresh_jwt text not null, 61 - dpop_pds_nonce text, 62 - dpop_auth_server_nonce text not null, 63 - dpop_private_jwk text not null, 64 - expiry text not null 65 - ); 42 + conn, err := db.Conn(ctx) 43 + if err != nil { 44 + return nil, err 45 + } 46 + defer conn.Close() 66 47 48 + _, err = conn.ExecContext(ctx, ` 67 49 create table if not exists profiles ( 68 50 -- id 69 51 id integer primary key autoincrement,
-173
internal/db/oauth.go
··· 1 - package db 2 - 3 - type OAuthRequest struct { 4 - ID uint 5 - AuthserverIss string 6 - Handle string 7 - State string 8 - Did string 9 - PdsUrl string 10 - PkceVerifier string 11 - DpopAuthserverNonce string 12 - DpopPrivateJwk string 13 - } 14 - 15 - type OAuthSession struct { 16 - ID uint 17 - Handle string 18 - Did string 19 - PdsUrl string 20 - AccessJwt string 21 - RefreshJwt string 22 - AuthServerIss string 23 - DpopPdsNonce string 24 - DpopAuthserverNonce string 25 - DpopPrivateJwk string 26 - Expiry string 27 - } 28 - 29 - func SaveOAuthRequest(e Execer, oauthRequest OAuthRequest) error { 30 - _, err := e.Exec(` 31 - insert into oauth_requests ( 32 - auth_server_iss, 33 - state, 34 - handle, 35 - did, 36 - pds_url, 37 - pkce_verifier, 38 - dpop_auth_server_nonce, 39 - dpop_private_jwk 40 - ) values (?, ?, ?, ?, ?, ?, ?, ?)`, 41 - oauthRequest.AuthserverIss, 42 - oauthRequest.State, 43 - oauthRequest.Handle, 44 - oauthRequest.Did, 45 - oauthRequest.PdsUrl, 46 - oauthRequest.PkceVerifier, 47 - oauthRequest.DpopAuthserverNonce, 48 - oauthRequest.DpopPrivateJwk, 49 - ) 50 - return err 51 - } 52 - 53 - func GetOAuthRequestByState(e Execer, state string) (OAuthRequest, error) { 54 - var req OAuthRequest 55 - err := e.QueryRow(` 56 - select 57 - id, 58 - auth_server_iss, 59 - handle, 60 - state, 61 - did, 62 - pds_url, 63 - pkce_verifier, 64 - dpop_auth_server_nonce, 65 - dpop_private_jwk 66 - from oauth_requests 67 - where state = ?`, state).Scan( 68 - &req.ID, 69 - &req.AuthserverIss, 70 - &req.Handle, 71 - &req.State, 72 - &req.Did, 73 - &req.PdsUrl, 74 - &req.PkceVerifier, 75 - &req.DpopAuthserverNonce, 76 - &req.DpopPrivateJwk, 77 - ) 78 - return req, err 79 - } 80 - 81 - func DeleteOAuthRequestByState(e Execer, state string) error { 82 - _, err := e.Exec(` 83 - delete from oauth_requests 84 - where state = ?`, state) 85 - return err 86 - } 87 - 88 - func SaveOAuthSession(e Execer, session OAuthSession) error { 89 - _, err := e.Exec(` 90 - insert into oauth_sessions ( 91 - did, 92 - handle, 93 - pds_url, 94 - access_jwt, 95 - refresh_jwt, 96 - auth_server_iss, 97 - dpop_auth_server_nonce, 98 - dpop_private_jwk, 99 - expiry 100 - ) values (?, ?, ?, ?, ?, ?, ?, ?, ?)`, 101 - session.Did, 102 - session.Handle, 103 - session.PdsUrl, 104 - session.AccessJwt, 105 - session.RefreshJwt, 106 - session.AuthServerIss, 107 - session.DpopAuthserverNonce, 108 - session.DpopPrivateJwk, 109 - session.Expiry, 110 - ) 111 - return err 112 - } 113 - 114 - func RefreshOAuthSession(e Execer, did string, accessJwt, refreshJwt, expiry string) error { 115 - _, err := e.Exec(` 116 - update oauth_sessions 117 - set access_jwt = ?, refresh_jwt = ?, expiry = ? 118 - where did = ?`, 119 - accessJwt, 120 - refreshJwt, 121 - expiry, 122 - did, 123 - ) 124 - return err 125 - } 126 - 127 - func GetOAuthSessionByDid(e Execer, did string) (*OAuthSession, error) { 128 - var session OAuthSession 129 - err := e.QueryRow(` 130 - select 131 - id, 132 - did, 133 - handle, 134 - pds_url, 135 - access_jwt, 136 - refresh_jwt, 137 - auth_server_iss, 138 - dpop_auth_server_nonce, 139 - dpop_private_jwk, 140 - expiry 141 - from oauth_sessions 142 - where did = ?`, did).Scan( 143 - &session.ID, 144 - &session.Did, 145 - &session.Handle, 146 - &session.PdsUrl, 147 - &session.AccessJwt, 148 - &session.RefreshJwt, 149 - &session.AuthServerIss, 150 - &session.DpopAuthserverNonce, 151 - &session.DpopPrivateJwk, 152 - &session.Expiry, 153 - ) 154 - return &session, err 155 - } 156 - 157 - func DeleteOAuthSessionByDid(e Execer, did string) error { 158 - _, err := e.Exec(` 159 - delete from oauth_sessions 160 - where did = ?`, did) 161 - return err 162 - } 163 - 164 - func UpdateDpopPdsNonce(e Execer, did string, dpopPdsNonce string) error { 165 - _, err := e.Exec(` 166 - update oauth_sessions 167 - set dpop_pds_nonce = ? 168 - where did = ?`, 169 - dpopPdsNonce, 170 - did, 171 - ) 172 - return err 173 - }
+6 -1
internal/server/app.go
··· 13 13 14 14 "yoten.app/api/yoten" 15 15 "yoten.app/internal/atproto" 16 + "yoten.app/internal/cache" 17 + "yoten.app/internal/cache/session" 16 18 "yoten.app/internal/clients/bsky" 17 19 "yoten.app/internal/consumer" 18 20 "yoten.app/internal/db" ··· 50 52 return nil, err 51 53 } 52 54 53 - oauth := oauth.NewOAuth(d, config) 55 + oauth, err := oauth.New(config) 56 + if err != nil { 57 + return nil, fmt.Errorf("failed to start oauth handler: %w", err) 58 + } 54 59 55 60 idResolver := atproto.DefaultResolver() 56 61
+7 -7
internal/server/handlers/activity.go
··· 69 69 SortedCategories: h.ComputedData.SortedCategories, 70 70 }).Render(r.Context(), w) 71 71 case http.MethodPost: 72 - client, err := h.Oauth.AuthorizedClient(r, w) 72 + client, err := h.Oauth.AuthorizedClient(r) 73 73 if err != nil { 74 74 log.Println("failed to get authorized client:", err) 75 75 htmx.HxRedirect(w, "/login") ··· 108 108 categoriesString = append(categoriesString, c.Name) 109 109 } 110 110 111 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 111 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 112 112 Collection: yoten.ActivityDefNSID, 113 113 Repo: user.Did, 114 114 Rkey: newActivity.Rkey, ··· 160 160 htmx.HxRedirect(w, "/login") 161 161 return 162 162 } 163 - client, err := h.Oauth.AuthorizedClient(r, w) 163 + client, err := h.Oauth.AuthorizedClient(r) 164 164 if err != nil { 165 165 log.Println("failed to get authorized client:", err) 166 166 htmx.HxError(w, http.StatusUnauthorized, "Failed to delete activity, try again later.") ··· 183 183 return 184 184 } 185 185 186 - _, err = client.RepoDeleteRecord(r.Context(), &comatproto.RepoDeleteRecord_Input{ 186 + _, err = comatproto.RepoDeleteRecord(r.Context(), client, &comatproto.RepoDeleteRecord_Input{ 187 187 Collection: yoten.ActivityDefNSID, 188 188 Repo: user.Did, 189 189 Rkey: activity.Rkey, ··· 245 245 SortedCategories: h.ComputedData.SortedCategories, 246 246 }).Render(r.Context(), w) 247 247 case http.MethodPost: 248 - client, err := h.Oauth.AuthorizedClient(r, w) 248 + client, err := h.Oauth.AuthorizedClient(r) 249 249 if err != nil { 250 250 log.Println("failed to get authorized client:", err) 251 251 htmx.HxRedirect(w, "/login") ··· 279 279 return 280 280 } 281 281 282 - ex, _ := client.RepoGetRecord(r.Context(), "", yoten.ActivityDefNSID, user.Did, updatedActivity.Rkey) 282 + ex, _ := comatproto.RepoGetRecord(r.Context(), client, "", yoten.ActivityDefNSID, user.Did, updatedActivity.Rkey) 283 283 var cid *string 284 284 if ex != nil { 285 285 cid = ex.Cid ··· 290 290 categoriesString = append(categoriesString, c.Name) 291 291 } 292 292 293 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 293 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 294 294 Collection: yoten.ActivityDefNSID, 295 295 Repo: user.Did, 296 296 Rkey: updatedActivity.Rkey,
+7 -7
internal/server/handlers/comment.go
··· 23 23 ) 24 24 25 25 func (h *Handler) HandleNewComment(w http.ResponseWriter, r *http.Request) { 26 - client, err := h.Oauth.AuthorizedClient(r, w) 26 + client, err := h.Oauth.AuthorizedClient(r) 27 27 if err != nil { 28 28 log.Println("failed to get authorized client:", err) 29 29 htmx.HxRedirect(w, "/login") ··· 85 85 CreatedAt: time.Now(), 86 86 } 87 87 88 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 88 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 89 89 Collection: yoten.FeedCommentNSID, 90 90 Repo: newComment.Did, 91 91 Rkey: newComment.Rkey, ··· 159 159 htmx.HxRedirect(w, "/login") 160 160 return 161 161 } 162 - client, err := h.Oauth.AuthorizedClient(r, w) 162 + client, err := h.Oauth.AuthorizedClient(r) 163 163 if err != nil { 164 164 log.Println("failed to get authorized client:", err) 165 165 htmx.HxRedirect(w, "/login") ··· 182 182 return 183 183 } 184 184 185 - _, err = client.RepoDeleteRecord(r.Context(), &comatproto.RepoDeleteRecord_Input{ 185 + _, err = comatproto.RepoDeleteRecord(r.Context(), client, &comatproto.RepoDeleteRecord_Input{ 186 186 Collection: yoten.FeedCommentNSID, 187 187 Repo: user.Did, 188 188 Rkey: comment.Rkey, ··· 243 243 case http.MethodGet: 244 244 partials.EditComment(partials.EditCommentProps{Comment: comment}).Render(r.Context(), w) 245 245 case http.MethodPost: 246 - client, err := h.Oauth.AuthorizedClient(r, w) 246 + client, err := h.Oauth.AuthorizedClient(r) 247 247 if err != nil { 248 248 log.Println("failed to get authorized client:", err) 249 249 htmx.HxRedirect(w, "/login") ··· 281 281 } 282 282 } 283 283 284 - ex, _ := client.RepoGetRecord(r.Context(), "", yoten.FeedCommentNSID, user.Did, updatedComment.Rkey) 284 + ex, _ := comatproto.RepoGetRecord(r.Context(), client, "", yoten.FeedCommentNSID, user.Did, updatedComment.Rkey) 285 285 var cid *string 286 286 if ex != nil { 287 287 cid = ex.Cid 288 288 } 289 289 290 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 290 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 291 291 Collection: yoten.FeedCommentNSID, 292 292 Repo: updatedComment.Did, 293 293 Rkey: updatedComment.Rkey,
+3 -3
internal/server/handlers/follow.go
··· 20 20 ) 21 21 22 22 func (h *Handler) HandleFollow(w http.ResponseWriter, r *http.Request) { 23 - client, err := h.Oauth.AuthorizedClient(r, w) 23 + client, err := h.Oauth.AuthorizedClient(r) 24 24 if err != nil { 25 25 log.Println("failed to get authorized client:", err) 26 26 htmx.HxRedirect(w, "/login") ··· 57 57 case http.MethodPost: 58 58 createdAt := time.Now().Format(time.RFC3339) 59 59 rkey := atproto.TID() 60 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 60 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 61 61 Collection: yoten.GraphFollowNSID, 62 62 Repo: user.Did, 63 63 Rkey: rkey, ··· 100 100 return 101 101 } 102 102 103 - _, err = client.RepoDeleteRecord(r.Context(), &comatproto.RepoDeleteRecord_Input{ 103 + _, err = comatproto.RepoDeleteRecord(r.Context(), client, &comatproto.RepoDeleteRecord_Input{ 104 104 Collection: yoten.GraphFollowNSID, 105 105 Repo: user.Did, 106 106 Rkey: follow.Rkey,
+112
internal/server/handlers/login.go
··· 1 + package handlers 2 + 3 + import ( 4 + "fmt" 5 + "log" 6 + "net/http" 7 + "strings" 8 + 9 + "github.com/posthog/posthog-go" 10 + 11 + "yoten.app/internal/clients/bsky" 12 + ph "yoten.app/internal/clients/posthog" 13 + "yoten.app/internal/server/htmx" 14 + "yoten.app/internal/server/views" 15 + "yoten.app/internal/types" 16 + ) 17 + 18 + func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { 19 + switch r.Method { 20 + case http.MethodGet: 21 + var user *types.User 22 + oauth := h.Oauth.GetUser(r) 23 + if oauth != nil { 24 + bskyProfile, err := bsky.GetBskyProfile(oauth.Did) 25 + if err != nil { 26 + log.Println("failed to get bsky profile:", err) 27 + } 28 + user = &types.User{ 29 + OauthUser: *oauth, 30 + BskyProfile: bskyProfile, 31 + } 32 + } 33 + 34 + returnURL := r.URL.Query().Get("return_url") 35 + views.LoginPage(views.LoginPageParams{ 36 + User: user, 37 + ReturnUrl: returnURL, 38 + }).Render(r.Context(), w) 39 + case http.MethodPost: 40 + handle := r.FormValue("handle") 41 + 42 + // When users copy their handle from bsky.app, it tends to have these 43 + // characters around it: 44 + // 45 + // @nelind.dk: 46 + // \u202a ensures that the handle is always rendered left to right and 47 + // \u202c reverts that so the rest of the page renders however it should 48 + handle = strings.TrimPrefix(handle, "\u202a") 49 + handle = strings.TrimSuffix(handle, "\u202c") 50 + 51 + // `@` is harmless 52 + handle = strings.TrimPrefix(handle, "@") 53 + 54 + // Basic handle validation 55 + if !strings.Contains(handle, ".") { 56 + log.Println("invalid handle format:", handle) 57 + htmx.HxError(w, http.StatusBadGateway, fmt.Sprintf("'%s' is an invalid handle. Did you mean %s.bsky.social?", handle, handle)) 58 + return 59 + } 60 + 61 + if !h.Config.Core.Dev { 62 + err := h.Posthog.Enqueue(posthog.Capture{ 63 + DistinctId: handle, 64 + Event: ph.UserSignInInitiatedEvent, 65 + }) 66 + if err != nil { 67 + log.Println("failed to enqueue posthog event:", err) 68 + } 69 + } 70 + 71 + redirectURL, err := h.Oauth.ClientApp.StartAuthFlow(r.Context(), handle) 72 + if err != nil { 73 + http.Error(w, err.Error(), http.StatusInternalServerError) 74 + return 75 + } 76 + 77 + if !h.Config.Core.Dev { 78 + err := h.Posthog.Enqueue(posthog.Capture{ 79 + DistinctId: handle, 80 + Event: ph.UserSignInSuccessEvent, 81 + }) 82 + if err != nil { 83 + log.Println("failed to enqueue posthog event:", err) 84 + } 85 + } 86 + 87 + htmx.HxRedirect(w, redirectURL) 88 + } 89 + } 90 + 91 + func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { 92 + did := h.Oauth.GetDid(r) 93 + 94 + err := h.Oauth.DeleteSession(w, r) 95 + if err != nil { 96 + log.Println("failed to logout", "err", err) 97 + } else { 98 + log.Println("logged out successfully") 99 + } 100 + 101 + if !h.Config.Core.Dev && did != "" { 102 + err := h.Posthog.Enqueue(posthog.Capture{ 103 + DistinctId: did, 104 + Event: ph.UserLoggedOutEvent, 105 + }) 106 + if err != nil { 107 + log.Println("failed to enqueue posthog event:", err) 108 + } 109 + } 110 + 111 + htmx.HxRedirect(w, "/login") 112 + }
+4 -4
internal/server/handlers/pending-ops.go
··· 14 14 } 15 15 16 16 func ApplyPendingChanges[T Rkeyer](h *Handler, w http.ResponseWriter, r *http.Request, items []T, createKey, updateKey, deleteKey string) ([]T, error) { 17 - yotenSession, err := h.Oauth.Store.Get(r, "yoten-session") 17 + yotenSession, err := h.Oauth.SessionStore.Get(r, "yoten-session") 18 18 if err != nil { 19 19 return items, err 20 20 } ··· 74 74 } 75 75 76 76 func SavePendingCreate[T any](h *Handler, w http.ResponseWriter, r *http.Request, sessionKey string, item T) error { 77 - yotenSession, err := h.Oauth.Store.Get(r, "yoten-session") 77 + yotenSession, err := h.Oauth.SessionStore.Get(r, "yoten-session") 78 78 if err != nil { 79 79 return fmt.Errorf("failed to get yoten-session for pending create: %w", err) 80 80 } ··· 93 93 } 94 94 95 95 func SavePendingUpdate[T Rkeyer](h *Handler, w http.ResponseWriter, r *http.Request, sessionKey string, item T) error { 96 - yotenSession, err := h.Oauth.Store.Get(r, "yoten-session") 96 + yotenSession, err := h.Oauth.SessionStore.Get(r, "yoten-session") 97 97 if err != nil { 98 98 return fmt.Errorf("failed to get yoten-session for pending update: %w", err) 99 99 } ··· 118 118 } 119 119 120 120 func SavePendingDelete[T Rkeyer](h *Handler, w http.ResponseWriter, r *http.Request, sessionKey string, item T) error { 121 - yotenSession, err := h.Oauth.Store.Get(r, "yoten-session") 121 + yotenSession, err := h.Oauth.SessionStore.Get(r, "yoten-session") 122 122 if err != nil { 123 123 return fmt.Errorf("failed to get yoten-session for pending delete: %w", err) 124 124 }
+5 -5
internal/server/handlers/profile.go
··· 212 212 InitialSelectedLanguages: profileLanguageCodes, 213 213 }).Render(r.Context(), w) 214 214 case http.MethodPost: 215 - client, err := h.Oauth.AuthorizedClient(r, w) 215 + client, err := h.Oauth.AuthorizedClient(r) 216 216 if err != nil { 217 217 log.Println("failed to get authorized client:", err) 218 218 htmx.HxRedirect(w, "/login") ··· 232 232 updatedProfile.Level = profile.Level 233 233 updatedProfile.Xp = profile.Xp 234 234 if updatedProfile.DisplayName == "" { 235 - updatedProfile.DisplayName = user.Handle 235 + updatedProfile.DisplayName = user.BskyProfile.Handle 236 236 } 237 237 238 238 if err := db.ValidateProfile(updatedProfile); err != nil { ··· 254 254 return 255 255 } 256 256 257 - ex, _ := client.RepoGetRecord(r.Context(), "", yoten.ActorProfileNSID, user.Did, "self") 257 + ex, _ := comatproto.RepoGetRecord(r.Context(), client, "", yoten.ActorProfileNSID, user.Did, "self") 258 258 var cid *string 259 259 if ex != nil { 260 260 cid = ex.Cid ··· 265 265 languagesStr = append(languagesStr, string(lc.Code)) 266 266 } 267 267 268 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 268 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 269 269 Collection: yoten.ActorProfileNSID, 270 270 Repo: user.Did, 271 271 Rkey: "self", ··· 299 299 Set("language_count", len(updatedProfile.Languages)). 300 300 Set("$set_once", posthog.NewProperties(). 301 301 Set("initial_did", user.Did). 302 - Set("initial_handle", user.Handle). 302 + Set("initial_handle", user.BskyProfile.Handle). 303 303 Set("created_at", updatedProfile.CreatedAt.Format(time.RFC3339)), 304 304 ) 305 305
+3 -3
internal/server/handlers/reaction.go
··· 21 21 ) 22 22 23 23 func (h *Handler) HandleReaction(w http.ResponseWriter, r *http.Request) { 24 - client, err := h.Oauth.AuthorizedClient(r, w) 24 + client, err := h.Oauth.AuthorizedClient(r) 25 25 if err != nil { 26 26 log.Println("failed to get authorized client:", err) 27 27 htmx.HxRedirect(w, "/login") ··· 101 101 102 102 createdAt := time.Now().Format(time.RFC3339) 103 103 rkey := atproto.TID() 104 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 104 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 105 105 Collection: yoten.FeedReactionNSID, 106 106 Repo: user.Did, 107 107 Rkey: rkey, ··· 158 158 return 159 159 } 160 160 161 - _, err = client.RepoDeleteRecord(r.Context(), &comatproto.RepoDeleteRecord_Input{ 161 + _, err = comatproto.RepoDeleteRecord(r.Context(), client, &comatproto.RepoDeleteRecord_Input{ 162 162 Collection: yoten.FeedReactionNSID, 163 163 Repo: user.Did, 164 164 Rkey: reactionEvent.Rkey,
+7 -7
internal/server/handlers/resource.go
··· 77 77 SortedResourceTypes: h.ComputedData.SortedResourceTypes, 78 78 }).Render(r.Context(), w) 79 79 case http.MethodPost: 80 - client, err := h.Oauth.AuthorizedClient(r, w) 80 + client, err := h.Oauth.AuthorizedClient(r) 81 81 if err != nil { 82 82 log.Println("failed to get authorized client:", err) 83 83 htmx.HxRedirect(w, "/login") ··· 177 177 feedResource.Link = newResource.Link 178 178 } 179 179 180 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 180 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 181 181 Collection: yoten.FeedResourceNSID, 182 182 Repo: user.Did, 183 183 Rkey: newResource.Rkey, ··· 224 224 htmx.HxRedirect(w, "/login") 225 225 return 226 226 } 227 - client, err := h.Oauth.AuthorizedClient(r, w) 227 + client, err := h.Oauth.AuthorizedClient(r) 228 228 if err != nil { 229 229 log.Println("failed to get authorized client:", err) 230 230 htmx.HxError(w, http.StatusUnauthorized, "Failed to delete resource, try again later.") ··· 247 247 return 248 248 } 249 249 250 - _, err = client.RepoDeleteRecord(r.Context(), &comatproto.RepoDeleteRecord_Input{ 250 + _, err = comatproto.RepoDeleteRecord(r.Context(), client, &comatproto.RepoDeleteRecord_Input{ 251 251 Collection: yoten.FeedResourceNSID, 252 252 Repo: user.Did, 253 253 Rkey: resource.Rkey, ··· 310 310 SortedResourceTypes: h.ComputedData.SortedResourceTypes, 311 311 }).Render(r.Context(), w) 312 312 case http.MethodPost: 313 - client, err := h.Oauth.AuthorizedClient(r, w) 313 + client, err := h.Oauth.AuthorizedClient(r) 314 314 if err != nil { 315 315 log.Println("failed to get authorized client:", err) 316 316 htmx.HxRedirect(w, "/login") ··· 411 411 feedResource.Link = updatedResource.Link 412 412 } 413 413 414 - ex, _ := client.RepoGetRecord(r.Context(), "", yoten.FeedResourceNSID, user.Did, resource.Rkey) 414 + ex, _ := comatproto.RepoGetRecord(r.Context(), client, "", yoten.FeedResourceNSID, user.Did, resource.Rkey) 415 415 var cid *string 416 416 if ex != nil { 417 417 cid = ex.Cid 418 418 } 419 419 420 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 420 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 421 421 Collection: yoten.FeedResourceNSID, 422 422 Repo: user.Did, 423 423 Rkey: updatedResource.Rkey,
+9 -12
internal/server/handlers/router.go
··· 5 5 "strings" 6 6 7 7 "github.com/go-chi/chi/v5" 8 - "github.com/gorilla/sessions" 9 8 10 9 "yoten.app/internal/server" 11 10 "yoten.app/internal/server/middleware" 12 - oauthhandler "yoten.app/internal/server/oauth/handler" 13 11 "yoten.app/internal/server/views" 14 12 ) 15 13 ··· 43 41 44 42 func (h *Handler) StandardRouter(mw *middleware.Middleware) http.Handler { 45 43 r := chi.NewRouter() 46 - r.Use(middleware.LoadUnreadNotificationCount(h.Oauth)) 44 + r.Use(mw.LoadUnreadNotificationCount()) 47 45 48 - r.Mount("/", h.OAuthRouter()) 49 46 r.Handle("/static/*", h.HandleStatic()) 47 + 50 48 r.Get("/", h.HandleIndexPage) 51 49 r.Get("/feed", h.HandleStudySessionFeed) 50 + 51 + r.Get("/login", h.Login) 52 + r.Post("/login", h.Login) 53 + r.Post("/logout", h.Logout) 52 54 53 55 r.Route("/friends", func(r chi.Router) { 54 56 r.Use(middleware.AuthMiddleware(h.Oauth)) ··· 125 127 }) 126 128 }) 127 129 130 + r.Mount("/", h.Oauth.Router()) 131 + 128 132 return r 129 133 } 130 134 131 135 func (h *Handler) UserRouter(mw *middleware.Middleware) http.Handler { 132 136 r := chi.NewRouter() 133 137 134 - r.Use(middleware.StripLeadingAt) 135 - r.Use(middleware.LoadUnreadNotificationCount(h.Oauth)) 138 + r.Use(mw.LoadUnreadNotificationCount()) 136 139 137 140 r.Group(func(r chi.Router) { 138 141 r.Use(mw.ResolveIdent()) ··· 153 156 154 157 return r 155 158 } 156 - 157 - func (h *Handler) OAuthRouter() http.Handler { 158 - store := sessions.NewCookieStore([]byte(h.Config.Core.CookieSecret)) 159 - oauth := oauthhandler.New(h.Config, h.Db, store, h.Oauth, h.Posthog) 160 - return oauth.Router() 161 - }
+7 -7
internal/server/handlers/study-session.go
··· 200 200 } 201 201 202 202 func (h *Handler) HandleEditStudySessionPage(w http.ResponseWriter, r *http.Request) { 203 - client, err := h.Oauth.AuthorizedClient(r, w) 203 + client, err := h.Oauth.AuthorizedClient(r) 204 204 if err != nil { 205 205 log.Println("failed to get authorized client:", err) 206 206 htmx.HxRedirect(w, "/login") ··· 342 342 updatedStudySessionRecord.PredefinedActivityName = &updatedStudySession.Activity.Name 343 343 } 344 344 345 - ex, _ := client.RepoGetRecord(r.Context(), "", yoten.FeedSessionNSID, user.Did, updatedStudySession.Rkey) 345 + ex, _ := comatproto.RepoGetRecord(r.Context(), client, "", yoten.FeedSessionNSID, user.Did, updatedStudySession.Rkey) 346 346 var cid *string 347 347 if ex != nil { 348 348 cid = ex.Cid 349 349 } 350 350 351 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 351 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 352 352 Collection: yoten.FeedSessionNSID, 353 353 Repo: updatedStudySession.Did, 354 354 Rkey: updatedStudySession.Rkey, ··· 393 393 return 394 394 } 395 395 396 - client, err := h.Oauth.AuthorizedClient(r, w) 396 + client, err := h.Oauth.AuthorizedClient(r) 397 397 if err != nil { 398 398 log.Println("failed to get authorized client:", err) 399 399 htmx.HxRedirect(w, "/login") ··· 502 502 newStudySessionRecord.PredefinedActivityName = &newStudySession.Activity.Name 503 503 } 504 504 505 - _, err = client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 505 + _, err = comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 506 506 Collection: yoten.FeedSessionNSID, 507 507 Repo: newStudySession.Did, 508 508 Rkey: newStudySession.Rkey, ··· 551 551 return 552 552 } 553 553 554 - client, err := h.Oauth.AuthorizedClient(r, w) 554 + client, err := h.Oauth.AuthorizedClient(r) 555 555 if err != nil { 556 556 log.Println("failed to get authorized client:", err) 557 557 htmx.HxError(w, http.StatusUnauthorized, "Failed to delete study session, try again later.") ··· 581 581 return 582 582 } 583 583 584 - _, err = client.RepoDeleteRecord(r.Context(), &comatproto.RepoDeleteRecord_Input{ 584 + _, err = comatproto.RepoDeleteRecord(r.Context(), client, &comatproto.RepoDeleteRecord_Input{ 585 585 Collection: yoten.FeedSessionNSID, 586 586 Repo: user.Did, 587 587 Rkey: rkey,
+8 -16
internal/server/middleware/middleware.go
··· 37 37 38 38 type middlewareFunc func(http.Handler) http.Handler 39 39 40 - func AuthMiddleware(a *oauth.OAuth) middlewareFunc { 40 + func AuthMiddleware(o *oauth.OAuth) middlewareFunc { 41 41 return func(next http.Handler) http.Handler { 42 42 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 43 43 returnURL := "/" ··· 57 57 } 58 58 } 59 59 60 - _, auth, err := a.GetSession(r) 60 + sess, err := o.ResumeSession(r) 61 61 if err != nil { 62 + log.Println("failed to resume session, redirecting...", "err", err, "url", r.URL.String()) 62 63 redirectFunc(w, r) 63 64 return 64 65 } 65 66 66 - if !auth { 67 + if sess == nil { 68 + log.Printf("session is nil, redirecting...") 67 69 redirectFunc(w, r) 68 70 return 69 71 } ··· 101 103 } 102 104 } 103 105 104 - func StripLeadingAt(next http.Handler) http.Handler { 105 - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 106 - path := req.URL.EscapedPath() 107 - if strings.HasPrefix(path, "/@") { 108 - req.URL.RawPath = "/" + strings.TrimPrefix(path, "/@") 109 - } 110 - next.ServeHTTP(w, req) 111 - }) 112 - } 113 - 114 - func LoadUnreadNotificationCount(o *oauth.OAuth) middlewareFunc { 106 + func (mw Middleware) LoadUnreadNotificationCount() middlewareFunc { 115 107 return func(next http.Handler) http.Handler { 116 108 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 117 - user := o.GetUser(r) 109 + user := mw.oauth.GetUser(r) 118 110 if user == nil { 119 111 next.ServeHTTP(w, r) 120 112 return 121 113 } 122 114 123 - count, err := db.GetUnreadNotificationCount(o.Db, user.Did) 115 + count, err := db.GetUnreadNotificationCount(mw.db, user.Did) 124 116 if err != nil { 125 117 log.Println("failed to get notification count:", err) 126 118 }
+5 -1
internal/server/oauth/consts.go
··· 1 1 package oauth 2 2 3 3 const ( 4 - SessionName = "yoten-oauth-session" 4 + SessionName = "yoten-oauth-session-v2" 5 5 SessionHandle = "handle" 6 6 SessionDid = "did" 7 + SessionId = "id" 7 8 SessionPds = "pds" 8 9 SessionAccessJwt = "accessJwt" 9 10 SessionRefreshJwt = "refreshJwt" 10 11 SessionExpiry = "expiry" 11 12 SessionAuthenticated = "authenticated" 13 + 14 + SessionDpopPrivateJwk = "dpopPrivateJwk" 15 + SessionDpopAuthServerNonce = "dpopAuthServerNonce" 12 16 )
+78
internal/server/oauth/handler.go
··· 1 + package oauth 2 + 3 + import ( 4 + "encoding/json" 5 + "log" 6 + "net/http" 7 + 8 + "github.com/go-chi/chi/v5" 9 + "github.com/lestrrat-go/jwx/v2/jwk" 10 + ) 11 + 12 + func (o *OAuth) Router() http.Handler { 13 + r := chi.NewRouter() 14 + 15 + r.Get("/oauth/client-metadata.json", o.clientMetadata) 16 + r.Get("/oauth/jwks.json", o.jwks) 17 + r.Get("/oauth/callback", o.callback) 18 + 19 + return r 20 + } 21 + 22 + func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) { 23 + doc := o.ClientApp.Config.ClientMetadata() 24 + doc.JWKSURI = &o.JwksUri 25 + 26 + w.Header().Set("Content-Type", "application/json") 27 + if err := json.NewEncoder(w).Encode(doc); err != nil { 28 + http.Error(w, err.Error(), http.StatusInternalServerError) 29 + return 30 + } 31 + } 32 + 33 + func pubKeyFromJwk(jwks string) (jwk.Key, error) { 34 + k, err := jwk.ParseKey([]byte(jwks)) 35 + if err != nil { 36 + return nil, err 37 + } 38 + pubKey, err := k.PublicKey() 39 + if err != nil { 40 + return nil, err 41 + } 42 + return pubKey, nil 43 + } 44 + 45 + func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) { 46 + jwks := o.Config.OAuth.Jwks 47 + pubKey, err := pubKeyFromJwk(jwks) 48 + if err != nil { 49 + log.Printf("failed to parse public key: %v", err) 50 + http.Error(w, err.Error(), http.StatusInternalServerError) 51 + return 52 + } 53 + 54 + response := map[string]any{ 55 + "keys": []jwk.Key{pubKey}, 56 + } 57 + 58 + w.Header().Set("Content-Type", "application/json") 59 + w.WriteHeader(http.StatusOK) 60 + json.NewEncoder(w).Encode(response) 61 + } 62 + 63 + func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) { 64 + ctx := r.Context() 65 + 66 + sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query()) 67 + if err != nil { 68 + http.Error(w, err.Error(), http.StatusInternalServerError) 69 + return 70 + } 71 + 72 + if err := o.SaveSession(w, r, sessData); err != nil { 73 + http.Error(w, err.Error(), http.StatusInternalServerError) 74 + return 75 + } 76 + 77 + http.Redirect(w, r, "/", http.StatusFound) 78 + }
-415
internal/server/oauth/handler/handler.go
··· 1 - package handler 2 - 3 - import ( 4 - "encoding/json" 5 - "fmt" 6 - "log" 7 - "net/http" 8 - "net/url" 9 - "strings" 10 - "time" 11 - 12 - comatproto "github.com/bluesky-social/indigo/api/atproto" 13 - lexutil "github.com/bluesky-social/indigo/lex/util" 14 - "github.com/go-chi/chi/v5" 15 - "github.com/gorilla/sessions" 16 - "github.com/posthog/posthog-go" 17 - "tangled.sh/icyphox.sh/atproto-oauth/helpers" 18 - 19 - "yoten.app/api/yoten" 20 - "yoten.app/internal/atproto" 21 - "yoten.app/internal/clients/bsky" 22 - ph "yoten.app/internal/clients/posthog" 23 - "yoten.app/internal/db" 24 - "yoten.app/internal/server/config" 25 - "yoten.app/internal/server/htmx" 26 - "yoten.app/internal/server/middleware" 27 - "yoten.app/internal/server/oauth" 28 - "yoten.app/internal/server/oauth/client" 29 - "yoten.app/internal/server/views" 30 - "yoten.app/internal/types" 31 - ) 32 - 33 - const ( 34 - oauthScope = "atproto transition:generic" 35 - ) 36 - 37 - type OAuthHandler struct { 38 - config *config.Config 39 - db *db.DB 40 - store *sessions.CookieStore 41 - oauth *oauth.OAuth 42 - posthog posthog.Client 43 - } 44 - 45 - func New( 46 - config *config.Config, 47 - db *db.DB, 48 - store *sessions.CookieStore, 49 - oauth *oauth.OAuth, 50 - posthog posthog.Client, 51 - ) *OAuthHandler { 52 - return &OAuthHandler{ 53 - config: config, 54 - db: db, 55 - store: store, 56 - oauth: oauth, 57 - posthog: posthog, 58 - } 59 - } 60 - 61 - func (o *OAuthHandler) Router() http.Handler { 62 - r := chi.NewRouter() 63 - 64 - r.Get("/login", o.HandleLoginPage) 65 - r.Post("/login", o.HandleLoginPage) 66 - 67 - r.With(middleware.AuthMiddleware(o.oauth)).Post("/logout", o.logout) 68 - 69 - r.Get("/oauth/client-metadata.json", o.clientMetadata) 70 - r.Get("/oauth/jwks.json", o.jwks) 71 - r.Get("/oauth/callback", o.callback) 72 - 73 - return r 74 - } 75 - 76 - func (o *OAuthHandler) HandleLoginPage(w http.ResponseWriter, r *http.Request) { 77 - switch r.Method { 78 - case http.MethodGet: 79 - var user *types.User 80 - oauth := o.oauth.GetUser(r) 81 - if oauth != nil { 82 - bskyProfile, err := bsky.GetBskyProfile(oauth.Did) 83 - if err != nil { 84 - log.Println("failed to get bsky profile:", err) 85 - } 86 - user = &types.User{ 87 - OauthUser: *oauth, 88 - BskyProfile: bskyProfile, 89 - } 90 - } 91 - views.LoginPage(views.LoginPageParams{ 92 - User: user, 93 - }).Render(r.Context(), w) 94 - case http.MethodPost: 95 - err := r.ParseForm() 96 - if err != nil { 97 - http.Error(w, "Bad Request", http.StatusBadRequest) 98 - return 99 - } 100 - 101 - handle := r.FormValue("handle") 102 - 103 - // When users copy their handle from bsky.app, it tends to have these 104 - // characters around it: 105 - // \u202a ensures that the handle is always rendered left to right and 106 - // \u202c reverts that so the rest of the page renders however it should 107 - handle = strings.TrimPrefix(handle, "\u202a") 108 - handle = strings.TrimSuffix(handle, "\u202c") 109 - 110 - handle = strings.TrimPrefix(handle, "@") 111 - 112 - idResolver := atproto.DefaultResolver() 113 - resolved, err := idResolver.ResolveIdent(r.Context(), handle) 114 - if err != nil { 115 - log.Println("failed to resolve handle:", err) 116 - htmx.HxError(w, http.StatusBadGateway, fmt.Sprintf("Failed to resolve identity - '%s' is an invalid handle.", handle)) 117 - return 118 - } 119 - 120 - cli := o.oauth.ClientMetadata() 121 - oauthClient, err := client.NewClient( 122 - cli.ClientID, 123 - o.config.OAuth.Jwks, 124 - cli.RedirectURIs[0], 125 - ) 126 - if err != nil { 127 - log.Println("failed to create oauth client:", err) 128 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 129 - return 130 - } 131 - 132 - authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint()) 133 - if err != nil { 134 - log.Println("failed to resolve auth server:", err) 135 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 136 - return 137 - } 138 - 139 - authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer) 140 - if err != nil { 141 - log.Println("failed to fetch auth server metadata:", err) 142 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 143 - return 144 - } 145 - 146 - dpopKey, err := helpers.GenerateKey(nil) 147 - if err != nil { 148 - log.Println("failed to generate dpop key:", err) 149 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 150 - return 151 - } 152 - 153 - dpopKeyJson, err := json.Marshal(dpopKey) 154 - if err != nil { 155 - log.Println("failed to marshal dpop key:", err) 156 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 157 - return 158 - } 159 - 160 - parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey) 161 - if err != nil { 162 - log.Println("failed to send par auth request:", err) 163 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 164 - return 165 - } 166 - 167 - err = db.SaveOAuthRequest(o.db, db.OAuthRequest{ 168 - Did: resolved.DID.String(), 169 - PdsUrl: resolved.PDSEndpoint(), 170 - Handle: handle, 171 - AuthserverIss: authMeta.Issuer, 172 - PkceVerifier: parResp.PkceVerifier, 173 - DpopAuthserverNonce: parResp.DpopAuthserverNonce, 174 - DpopPrivateJwk: string(dpopKeyJson), 175 - State: parResp.State, 176 - }) 177 - if err != nil { 178 - log.Println("failed to save oauth request:", err) 179 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 180 - return 181 - } 182 - 183 - if !o.config.Core.Dev { 184 - err := o.posthog.Enqueue(posthog.Capture{ 185 - DistinctId: resolved.DID.String(), 186 - Event: ph.UserSignInInitiatedEvent, 187 - }) 188 - if err != nil { 189 - log.Println("failed to enqueue posthog event:", err) 190 - } 191 - } 192 - 193 - u, _ := url.Parse(authMeta.AuthorizationEndpoint) 194 - query := url.Values{} 195 - query.Add("client_id", cli.ClientID) 196 - query.Add("request_uri", parResp.RequestUri) 197 - u.RawQuery = query.Encode() 198 - htmx.HxRedirect(w, u.String()) 199 - } 200 - } 201 - 202 - func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) { 203 - did := o.oauth.GetDid(r) 204 - err := o.oauth.ClearSession(r, w) 205 - if err != nil { 206 - log.Println("failed to clear session:", err) 207 - http.Redirect(w, r, "/", http.StatusFound) 208 - return 209 - } 210 - 211 - if !o.config.Core.Dev && did != "" { 212 - err := o.posthog.Enqueue(posthog.Capture{ 213 - DistinctId: did, 214 - Event: ph.UserLoggedOutEvent, 215 - }) 216 - if err != nil { 217 - log.Println("failed to enqueue posthog event:", err) 218 - } 219 - } 220 - 221 - htmx.HxRedirect(w, "/login") 222 - } 223 - 224 - func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) { 225 - jwks := o.config.OAuth.Jwks 226 - k, err := helpers.ParseJWKFromBytes([]byte(jwks)) 227 - if err != nil { 228 - log.Printf("failed to parse jwks: %v", err) 229 - http.Error(w, "Internal Server Error", 500) 230 - } 231 - 232 - pubKey, err := k.PublicKey() 233 - if err != nil { 234 - log.Printf("failed to parse jwks public key: %v", err) 235 - http.Error(w, "Internal Server Error", 500) 236 - } 237 - 238 - w.Header().Set("Content-Type", "application/json") 239 - w.WriteHeader(http.StatusOK) 240 - json.NewEncoder(w).Encode(helpers.CreateJwksResponseObject(pubKey)) 241 - } 242 - 243 - func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) { 244 - w.Header().Set("Content-Type", "application/json") 245 - w.WriteHeader(http.StatusOK) 246 - json.NewEncoder(w).Encode(o.oauth.ClientMetadata()) 247 - } 248 - 249 - func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { 250 - state := r.FormValue("state") 251 - 252 - oauthRequest, err := db.GetOAuthRequestByState(o.db, state) 253 - if err != nil { 254 - log.Println("failed to get oauth request:", err) 255 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 256 - return 257 - } 258 - 259 - defer func() { 260 - err := db.DeleteOAuthRequestByState(o.db, state) 261 - if err != nil { 262 - log.Printf("failed to delete oauth request for state '%s': %v", state, err) 263 - } 264 - }() 265 - 266 - callbackErr := r.FormValue("error") 267 - errorDescription := r.FormValue("error_description") 268 - if callbackErr != "" || errorDescription != "" { 269 - log.Printf("oauth callback error: %s, %s", callbackErr, errorDescription) 270 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 271 - return 272 - } 273 - 274 - iss := r.FormValue("iss") 275 - if iss == "" { 276 - log.Println("missing iss for state: ", state) 277 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 278 - return 279 - } 280 - 281 - code := r.FormValue("code") 282 - if code == "" { 283 - log.Println("missing code for state: ", state) 284 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 285 - return 286 - } 287 - 288 - if iss != oauthRequest.AuthserverIss { 289 - log.Println("mismatched iss:", iss, "!=", oauthRequest.AuthserverIss, "for state:", state) 290 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 291 - return 292 - } 293 - cli := o.oauth.ClientMetadata() 294 - oauthClient, err := client.NewClient( 295 - cli.ClientID, 296 - o.config.OAuth.Jwks, 297 - cli.RedirectURIs[0], 298 - ) 299 - if err != nil { 300 - log.Println("failed to create oauth client:", err) 301 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 302 - return 303 - } 304 - 305 - jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk)) 306 - if err != nil { 307 - log.Println("failed to parse jwk:", err) 308 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 309 - return 310 - } 311 - 312 - tokenResp, err := oauthClient.InitialTokenRequest( 313 - r.Context(), 314 - code, 315 - oauthRequest.AuthserverIss, 316 - oauthRequest.PkceVerifier, 317 - oauthRequest.DpopAuthserverNonce, 318 - jwk, 319 - ) 320 - if err != nil { 321 - log.Println("failed to get token:", err) 322 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 323 - return 324 - } 325 - 326 - if tokenResp.Scope != oauthScope { 327 - log.Println("oauth scope doesn't match:", tokenResp.Scope) 328 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 329 - return 330 - } 331 - 332 - userSession, err := o.oauth.SaveSession(w, r, oauthRequest, tokenResp) 333 - if err != nil { 334 - log.Println("failed to save user session:", err) 335 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 336 - return 337 - } 338 - 339 - if !o.config.Core.Dev { 340 - err = o.posthog.Enqueue(posthog.Capture{ 341 - DistinctId: oauthRequest.Did, 342 - Event: ph.UserSignInSuccessEvent, 343 - }) 344 - if err != nil { 345 - log.Println("failed to enqueue posthog event:", err) 346 - } 347 - } 348 - 349 - xrpcClient, err := o.oauth.AuthorizedClientFromSession(*userSession, r, w) 350 - if err != nil { 351 - log.Println("failed to retrieve authorized client:", err) 352 - htmx.HxError(w, http.StatusUnauthorized, "Failed to authenticate. Try again later.") 353 - return 354 - } 355 - 356 - ex, _ := xrpcClient.RepoGetRecord(r.Context(), "", yoten.ActorProfileNSID, oauthRequest.Did, "self") 357 - var cid *string 358 - if ex != nil { 359 - cid = ex.Cid 360 - } 361 - 362 - // This should only occur once per account 363 - if ex == nil { 364 - createdAt := time.Now().Format(time.RFC3339) 365 - atresp, err := xrpcClient.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 366 - Collection: yoten.ActorProfileNSID, 367 - Repo: oauthRequest.Did, 368 - Rkey: "self", 369 - Record: &lexutil.LexiconTypeDecoder{ 370 - Val: &yoten.ActorProfile{ 371 - DisplayName: oauthRequest.Handle, 372 - Description: db.ToPtr(""), 373 - Languages: make([]string, 0), 374 - Location: db.ToPtr(""), 375 - CreatedAt: createdAt, 376 - }}, 377 - SwapRecord: cid, 378 - }) 379 - if err != nil { 380 - log.Println("failed to create record:", err) 381 - htmx.HxError(w, http.StatusInternalServerError, "Failed to announce profile creation, try again later") 382 - return 383 - } 384 - 385 - log.Println("created profile record:", atresp.Uri) 386 - if !o.config.Core.Dev { 387 - properties := posthog.NewProperties(). 388 - Set("display_name", oauthRequest.Handle). 389 - Set("language_count", 0). 390 - Set("$set_once", posthog.NewProperties(). 391 - Set("initial_did", oauthRequest.Did). 392 - Set("initial_handle", oauthRequest.Handle). 393 - Set("created_at", createdAt), 394 - ) 395 - 396 - err = o.posthog.Enqueue(posthog.Identify{ 397 - DistinctId: oauthRequest.Did, 398 - Properties: properties, 399 - }) 400 - if err != nil { 401 - log.Println("failed to enqueue posthog identify event:", err) 402 - } 403 - 404 - err = o.posthog.Enqueue(posthog.Capture{ 405 - DistinctId: oauthRequest.Did, 406 - Event: ph.ProfileRecordCreatedEvent, 407 - }) 408 - if err != nil { 409 - log.Println("failed to enqueue posthog event:", err) 410 - } 411 - } 412 - } 413 - 414 - http.Redirect(w, r, "/", http.StatusFound) 415 - }
+148 -215
internal/server/oauth/oauth.go
··· 1 1 package oauth 2 2 3 3 import ( 4 + "errors" 4 5 "fmt" 5 - "log" 6 6 "net/http" 7 - "net/url" 8 7 "time" 9 8 9 + comatproto "github.com/bluesky-social/indigo/api/atproto" 10 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 11 + atpclient "github.com/bluesky-social/indigo/atproto/client" 12 + "github.com/bluesky-social/indigo/atproto/syntax" 13 + xrpc "github.com/bluesky-social/indigo/xrpc" 10 14 "github.com/gorilla/sessions" 11 - oauth "tangled.sh/icyphox.sh/atproto-oauth" 12 - "tangled.sh/icyphox.sh/atproto-oauth/helpers" 13 15 14 - xrpc "yoten.app/internal/atproto" 15 - "yoten.app/internal/db" 16 16 "yoten.app/internal/server/config" 17 - "yoten.app/internal/server/oauth/client" 18 17 "yoten.app/internal/types" 19 18 ) 20 19 21 20 type OAuth struct { 22 - Store *sessions.CookieStore 23 - Db *db.DB 24 - Config *config.Config 21 + ClientApp *oauth.ClientApp 22 + SessionStore *sessions.CookieStore 23 + Config *config.Config 24 + JwksUri string 25 25 } 26 26 27 - func NewOAuth(db *db.DB, config *config.Config) *OAuth { 28 - return &OAuth{ 29 - Store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), 30 - Db: db, 31 - Config: config, 27 + func New(config *config.Config) (*OAuth, error) { 28 + var oauthConfig oauth.ClientConfig 29 + var clientUri string 30 + 31 + if config.Core.Dev { 32 + clientUri = "http://127.0.0.1:" + config.Core.Port 33 + callbackUri := clientUri + "/oauth/callback" 34 + oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"}) 35 + } else { 36 + clientUri = config.Core.Host 37 + clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 38 + callbackUri := clientUri + "/oauth/callback" 39 + oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"}) 32 40 } 33 - } 41 + 42 + jwksUri := clientUri + "/oauth/jwks.json" 34 43 35 - func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) (*sessions.Session, error) { 36 - // Save did in user session. 37 - userSession, err := o.Store.Get(r, SessionName) 44 + authStore, err := NewRedisStore(config.Redis.ToURL()) 38 45 if err != nil { 39 46 return nil, err 40 47 } 41 48 42 - userSession.Values[SessionDid] = oreq.Did 43 - userSession.Values[SessionHandle] = oreq.Handle 44 - userSession.Values[SessionPds] = oreq.PdsUrl 45 - userSession.Values[SessionAuthenticated] = true 46 - err = userSession.Save(r, w) 47 - if err != nil { 48 - return nil, fmt.Errorf("failed to save user session: %w", err) 49 - } 49 + sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 50 50 51 - // Save the whole thing in the db. 52 - session := db.OAuthSession{ 53 - Did: oreq.Did, 54 - Handle: oreq.Handle, 55 - PdsUrl: oreq.PdsUrl, 56 - DpopAuthserverNonce: oreq.DpopAuthserverNonce, 57 - AuthServerIss: oreq.AuthserverIss, 58 - DpopPrivateJwk: oreq.DpopPrivateJwk, 59 - AccessJwt: oresp.AccessToken, 60 - RefreshJwt: oresp.RefreshToken, 61 - Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339), 62 - } 51 + return &OAuth{ 52 + ClientApp: oauth.NewClientApp(&oauthConfig, authStore), 53 + Config: config, 54 + SessionStore: sessStore, 55 + JwksUri: jwksUri, 56 + }, nil 63 57 64 - return userSession, db.SaveOAuthSession(o.Db, session) 65 58 } 66 59 67 - func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error { 68 - userSession, err := o.Store.Get(r, SessionName) 60 + func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessionData *oauth.ClientSessionData) error { 61 + // Save did in user session. 62 + userSession, err := o.SessionStore.Get(r, SessionName) 69 63 if err != nil { 70 - return fmt.Errorf("failed to get user session: %w", err) 71 - } 72 - if userSession.IsNew { 73 - return fmt.Errorf("user session is new") 64 + return err 74 65 } 75 66 76 - did := userSession.Values[SessionDid].(string) 77 - 78 - err = db.DeleteOAuthSessionByDid(o.Db, did) 67 + userSession.Values[SessionDid] = sessionData.AccountDID.String() 68 + userSession.Values[SessionPds] = sessionData.HostURL 69 + userSession.Values[SessionId] = sessionData.SessionID 70 + userSession.Values[SessionAuthenticated] = true 71 + err = userSession.Save(r, w) 79 72 if err != nil { 80 - return fmt.Errorf("failed to delete oauth session: %w", err) 73 + return fmt.Errorf("failed to save user session: %w", err) 81 74 } 82 75 83 - userSession.Options.MaxAge = -1 84 - 85 - return userSession.Save(r, w) 76 + return nil 86 77 } 87 78 88 - func (o *OAuth) CheckSessionAuth(userSession sessions.Session, r *http.Request) (*db.OAuthSession, bool, error) { 89 - did := userSession.Values[SessionDid].(string) 90 - auth := userSession.Values[SessionAuthenticated].(bool) 91 - 92 - session, err := db.GetOAuthSessionByDid(o.Db, did) 79 + func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 80 + userSession, err := o.SessionStore.Get(r, SessionName) 93 81 if err != nil { 94 - return nil, false, fmt.Errorf("failed to get oauth session: %w", err) 82 + return nil, fmt.Errorf("failed to retrieve user session: %w", err) 83 + } 84 + if userSession.IsNew { 85 + return nil, fmt.Errorf("no session available for user") 95 86 } 96 87 97 - expiry, err := time.Parse(time.RFC3339, session.Expiry) 88 + d := userSession.Values[SessionDid].(string) 89 + sessionDid, err := syntax.ParseDID(d) 98 90 if err != nil { 99 - return nil, false, fmt.Errorf("failed to parse expiry time: %w", err) 91 + return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 100 92 } 101 93 102 - if expiry.Sub(time.Now()) <= 5*time.Minute { 103 - privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 104 - if err != nil { 105 - return nil, false, err 106 - } 107 - 108 - self := o.ClientMetadata() 94 + sessionId := userSession.Values[SessionId].(string) 109 95 110 - oauthClient, err := client.NewClient( 111 - self.ClientID, 112 - o.Config.OAuth.Jwks, 113 - self.RedirectURIs[0], 114 - ) 115 - 116 - if err != nil { 117 - return nil, false, err 118 - } 119 - 120 - resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk) 121 - if err != nil { 122 - log.Printf("failed to refresh token for did '%s', deleting session: %v", did, err) 123 - if delErr := db.DeleteOAuthSessionByDid(o.Db, did); delErr != nil { 124 - log.Printf("failed to delete stale oauth session for did '%s': %v", did, delErr) 125 - } 126 - return nil, false, fmt.Errorf("session expired and could not be refreshed: %w", err) 127 - } 128 - 129 - newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339) 130 - err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry) 131 - if err != nil { 132 - return nil, false, fmt.Errorf("failed to refresh oauth session: %w", err) 133 - } 134 - 135 - // Update the current session. 136 - session.AccessJwt = resp.AccessToken 137 - session.RefreshJwt = resp.RefreshToken 138 - session.DpopAuthserverNonce = resp.DpopAuthserverNonce 139 - session.Expiry = newExpiry 96 + clientSession, err := o.ClientApp.ResumeSession(r.Context(), sessionDid, sessionId) 97 + if err != nil { 98 + return nil, fmt.Errorf("failed to resume session: %w", err) 140 99 } 141 100 142 - return session, auth, nil 101 + return clientSession, nil 143 102 } 144 103 145 - func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) { 146 - userSession, err := o.Store.Get(r, SessionName) 104 + func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 105 + userSession, err := o.SessionStore.Get(r, SessionName) 147 106 if err != nil { 148 - return nil, false, fmt.Errorf("failed to get user session: %w", err) 107 + return fmt.Errorf("failed to retrieve user session: %w", err) 149 108 } 150 109 if userSession.IsNew { 151 - return nil, false, fmt.Errorf("user session is new") 110 + return fmt.Errorf("no session available for user") 152 111 } 153 112 154 - session, auth, err := o.CheckSessionAuth(*userSession, r) 113 + d := userSession.Values[SessionDid].(string) 114 + sessionDid, err := syntax.ParseDID(d) 155 115 if err != nil { 156 - return nil, false, fmt.Errorf("failed to check user session auth: %w", err) 116 + return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 157 117 } 158 118 159 - return session, auth, nil 119 + sessionId := userSession.Values[SessionId].(string) 120 + 121 + // Delete the session. 122 + err1 := o.ClientApp.Logout(r.Context(), sessionDid, sessionId) 123 + 124 + // Remove the cookie. 125 + userSession.Options.MaxAge = -1 126 + err2 := o.SessionStore.Save(r, w, userSession) 127 + 128 + return errors.Join(err1, err2) 160 129 } 161 130 162 - func (a *OAuth) GetUser(r *http.Request) *types.OauthUser { 163 - clientSession, err := a.Store.Get(r, SessionName) 131 + func (o *OAuth) GetUser(r *http.Request) *types.OauthUser { 132 + clientSession, err := o.SessionStore.Get(r, SessionName) 164 133 if err != nil || clientSession.IsNew { 165 134 return nil 166 135 } 167 136 168 137 return &types.OauthUser{ 169 - Handle: clientSession.Values[SessionHandle].(string), 170 - Did: clientSession.Values[SessionDid].(string), 171 - Pds: clientSession.Values[SessionPds].(string), 138 + Did: clientSession.Values[SessionDid].(string), 139 + Pds: clientSession.Values[SessionPds].(string), 172 140 } 173 141 } 174 142 175 - func (a *OAuth) GetDid(r *http.Request) string { 176 - clientSession, err := a.Store.Get(r, SessionName) 177 - if err != nil || clientSession.IsNew { 178 - return "" 143 + func (o *OAuth) GetDid(r *http.Request) string { 144 + if u := o.GetUser(r); u != nil { 145 + return u.Did 179 146 } 180 147 181 - return clientSession.Values[SessionDid].(string) 148 + return "" 182 149 } 183 150 184 - func (o *OAuth) AuthorizedClientFromSession(userSession sessions.Session, r *http.Request, w http.ResponseWriter) (*xrpc.Client, error) { 185 - session, auth, err := o.CheckSessionAuth(userSession, r) 151 + func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 152 + session, err := o.ResumeSession(r) 186 153 if err != nil { 187 - o.ClearSession(r, w) 188 - return nil, fmt.Errorf("failed to get session: %w", err) 189 - } 190 - if !auth { 191 - return nil, fmt.Errorf("not authorized") 154 + return nil, fmt.Errorf("failed to retrieve session: %w", err) 192 155 } 193 156 194 - client := &oauth.XrpcClient{ 195 - OnDpopPdsNonceChanged: func(did, newNonce string) { 196 - err := db.UpdateDpopPdsNonce(o.Db, did, newNonce) 197 - if err != nil { 198 - log.Printf("failed to update dpop pds nonce: %v", err) 199 - } 200 - }, 201 - } 157 + return session.APIClient(), nil 158 + } 202 159 203 - privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 204 - if err != nil { 205 - return nil, fmt.Errorf("failed to parse private jwk: %w", err) 206 - } 160 + // this is a higher level abstraction on ServerGetServiceAuth 161 + type ServiceClientOpts struct { 162 + service string 163 + exp int64 164 + lxm string 165 + dev bool 166 + } 207 167 208 - xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{ 209 - Did: session.Did, 210 - PdsUrl: session.PdsUrl, 211 - DpopPdsNonce: session.PdsUrl, 212 - AccessToken: session.AccessJwt, 213 - Issuer: session.AuthServerIss, 214 - DpopPrivateJwk: privateJwk, 215 - }) 168 + type ServiceClientOpt func(*ServiceClientOpts) 216 169 217 - return xrpcClient, nil 170 + func WithService(service string) ServiceClientOpt { 171 + return func(s *ServiceClientOpts) { 172 + s.service = service 173 + } 218 174 } 219 175 220 - func (o *OAuth) AuthorizedClient(r *http.Request, w http.ResponseWriter) (*xrpc.Client, error) { 221 - session, auth, err := o.GetSession(r) 222 - if err != nil { 223 - o.ClearSession(r, w) 224 - return nil, fmt.Errorf("failed to get session: %w", err) 176 + // Specify the Duration in seconds for the expiry of this token 177 + // 178 + // The time of expiry is calculated as time.Now().Unix() + exp 179 + func WithExp(exp int64) ServiceClientOpt { 180 + return func(s *ServiceClientOpts) { 181 + s.exp = time.Now().Unix() + exp 225 182 } 226 - if !auth { 227 - return nil, fmt.Errorf("not authorized") 228 - } 183 + } 229 184 230 - client := &oauth.XrpcClient{ 231 - OnDpopPdsNonceChanged: func(did, newNonce string) { 232 - err := db.UpdateDpopPdsNonce(o.Db, did, newNonce) 233 - if err != nil { 234 - log.Printf("failed to update dpop pds nonce: %v", err) 235 - } 236 - }, 185 + func WithLxm(lxm string) ServiceClientOpt { 186 + return func(s *ServiceClientOpts) { 187 + s.lxm = lxm 237 188 } 189 + } 238 190 239 - privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 240 - if err != nil { 241 - return nil, fmt.Errorf("failed to parse private jwk: %w", err) 191 + func WithDev(dev bool) ServiceClientOpt { 192 + return func(s *ServiceClientOpts) { 193 + s.dev = dev 242 194 } 243 - 244 - xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{ 245 - Did: session.Did, 246 - PdsUrl: session.PdsUrl, 247 - DpopPdsNonce: session.PdsUrl, 248 - AccessToken: session.AccessJwt, 249 - Issuer: session.AuthServerIss, 250 - DpopPrivateJwk: privateJwk, 251 - }) 252 - 253 - return xrpcClient, nil 254 195 } 255 196 256 - type ClientMetadata struct { 257 - ClientID string `json:"client_id"` 258 - ClientName string `json:"client_name"` 259 - SubjectType string `json:"subject_type"` 260 - ClientURI string `json:"client_uri"` 261 - RedirectURIs []string `json:"redirect_uris"` 262 - GrantTypes []string `json:"grant_types"` 263 - ResponseTypes []string `json:"response_types"` 264 - ApplicationType string `json:"application_type"` 265 - DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"` 266 - JwksURI string `json:"jwks_uri"` 267 - Scope string `json:"scope"` 268 - TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` 269 - TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` 197 + func (s *ServiceClientOpts) Audience() string { 198 + return fmt.Sprintf("did:web:%s", s.service) 270 199 } 271 200 272 - func (o *OAuth) ClientMetadata() ClientMetadata { 273 - makeRedirectURIs := func(c string) []string { 274 - return []string{fmt.Sprintf("%s/oauth/callback", c)} 201 + func (s *ServiceClientOpts) Host() string { 202 + scheme := "https://" 203 + if s.dev { 204 + scheme = "http://" 275 205 } 276 206 277 - clientURI := o.Config.Core.Host 278 - clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI) 279 - redirectURIs := makeRedirectURIs(clientURI) 207 + return scheme + s.service 208 + } 280 209 281 - if o.Config.Core.Dev { 282 - clientURI = "http://127.0.0.1:8080" 283 - redirectURIs = makeRedirectURIs(clientURI) 210 + func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 211 + opts := ServiceClientOpts{} 212 + for _, o := range os { 213 + o(&opts) 214 + } 284 215 285 - query := url.Values{} 286 - query.Add("redirect_uri", redirectURIs[0]) 287 - query.Add("scope", "atproto transition:generic") 288 - clientID = fmt.Sprintf("http://localhost?%s", query.Encode()) 216 + client, err := o.AuthorizedClient(r) 217 + if err != nil { 218 + return nil, err 289 219 } 290 220 291 - jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI) 221 + // force expiry to atleast 60 seconds in the future 222 + sixty := time.Now().Unix() + 60 223 + if opts.exp < sixty { 224 + opts.exp = sixty 225 + } 292 226 293 - return ClientMetadata{ 294 - ClientID: clientID, 295 - ClientName: "Yoten", 296 - SubjectType: "public", 297 - ClientURI: clientURI, 298 - RedirectURIs: redirectURIs, 299 - GrantTypes: []string{"authorization_code", "refresh_token"}, 300 - ResponseTypes: []string{"code"}, 301 - ApplicationType: "web", 302 - DpopBoundAccessTokens: true, 303 - JwksURI: jwksURI, 304 - Scope: "atproto transition:generic", 305 - TokenEndpointAuthMethod: "private_key_jwt", 306 - TokenEndpointAuthSigningAlg: "ES256", 227 + resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 228 + if err != nil { 229 + return nil, err 307 230 } 231 + 232 + return &xrpc.Client{ 233 + Auth: &xrpc.AuthInfo{ 234 + AccessJwt: resp.Token, 235 + }, 236 + Host: opts.Host(), 237 + Client: &http.Client{ 238 + Timeout: time.Second * 5, 239 + }, 240 + }, nil 308 241 }
+148
internal/server/oauth/store.go
··· 1 + package oauth 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "fmt" 7 + "time" 8 + 9 + "github.com/bluesky-social/indigo/atproto/auth/oauth" 10 + "github.com/bluesky-social/indigo/atproto/syntax" 11 + "github.com/redis/go-redis/v9" 12 + ) 13 + 14 + // redis-backed implementation of ClientAuthStore. 15 + type RedisStore struct { 16 + client *redis.Client 17 + SessionTTL time.Duration 18 + AuthRequestTTL time.Duration 19 + } 20 + 21 + var _ oauth.ClientAuthStore = &RedisStore{} 22 + 23 + func NewRedisStore(redisURL string) (*RedisStore, error) { 24 + fmt.Println(redisURL) 25 + opts, err := redis.ParseURL(redisURL) 26 + if err != nil { 27 + return nil, fmt.Errorf("failed to parse redis URL: %w", err) 28 + } 29 + 30 + client := redis.NewClient(opts) 31 + 32 + // Test the connection. 33 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 34 + defer cancel() 35 + 36 + if err := client.Ping(ctx).Err(); err != nil { 37 + return nil, fmt.Errorf("failed to connect to redis: %w", err) 38 + } 39 + 40 + return &RedisStore{ 41 + client: client, 42 + SessionTTL: 30 * 24 * time.Hour, // 30 days 43 + AuthRequestTTL: 10 * time.Minute, // 10 minutes 44 + }, nil 45 + } 46 + 47 + func (r *RedisStore) Close() error { 48 + return r.client.Close() 49 + } 50 + 51 + func sessionKey(did syntax.DID, sessionID string) string { 52 + return fmt.Sprintf("oauth:session:%s:%s", did, sessionID) 53 + } 54 + 55 + func authRequestKey(state string) string { 56 + return fmt.Sprintf("oauth:auth_request:%s", state) 57 + } 58 + 59 + func (r *RedisStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { 60 + key := sessionKey(did, sessionID) 61 + data, err := r.client.Get(ctx, key).Bytes() 62 + if err == redis.Nil { 63 + return nil, fmt.Errorf("session not found: %s", did) 64 + } 65 + if err != nil { 66 + return nil, fmt.Errorf("failed to get session: %w", err) 67 + } 68 + 69 + var sess oauth.ClientSessionData 70 + if err := json.Unmarshal(data, &sess); err != nil { 71 + return nil, fmt.Errorf("failed to unmarshal session: %w", err) 72 + } 73 + 74 + return &sess, nil 75 + } 76 + 77 + func (r *RedisStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { 78 + key := sessionKey(sess.AccountDID, sess.SessionID) 79 + 80 + data, err := json.Marshal(sess) 81 + if err != nil { 82 + return fmt.Errorf("failed to marshal session: %w", err) 83 + } 84 + 85 + if err := r.client.Set(ctx, key, data, r.SessionTTL).Err(); err != nil { 86 + return fmt.Errorf("failed to save session: %w", err) 87 + } 88 + 89 + return nil 90 + } 91 + 92 + func (r *RedisStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 93 + key := sessionKey(did, sessionID) 94 + if err := r.client.Del(ctx, key).Err(); err != nil { 95 + return fmt.Errorf("failed to delete session: %w", err) 96 + } 97 + return nil 98 + } 99 + 100 + func (r *RedisStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { 101 + key := authRequestKey(state) 102 + data, err := r.client.Get(ctx, key).Bytes() 103 + if err == redis.Nil { 104 + return nil, fmt.Errorf("request info not found: %s", state) 105 + } 106 + if err != nil { 107 + return nil, fmt.Errorf("failed to get auth request: %w", err) 108 + } 109 + 110 + var req oauth.AuthRequestData 111 + if err := json.Unmarshal(data, &req); err != nil { 112 + return nil, fmt.Errorf("failed to unmarshal auth request: %w", err) 113 + } 114 + 115 + return &req, nil 116 + } 117 + 118 + func (r *RedisStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 119 + key := authRequestKey(info.State) 120 + 121 + // check if already exists (to match MemStore behavior) 122 + exists, err := r.client.Exists(ctx, key).Result() 123 + if err != nil { 124 + return fmt.Errorf("failed to check auth request existence: %w", err) 125 + } 126 + if exists > 0 { 127 + return fmt.Errorf("auth request already saved for state %s", info.State) 128 + } 129 + 130 + data, err := json.Marshal(info) 131 + if err != nil { 132 + return fmt.Errorf("failed to marshal auth request: %w", err) 133 + } 134 + 135 + if err := r.client.Set(ctx, key, data, r.AuthRequestTTL).Err(); err != nil { 136 + return fmt.Errorf("failed to save auth request: %w", err) 137 + } 138 + 139 + return nil 140 + } 141 + 142 + func (r *RedisStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 143 + key := authRequestKey(state) 144 + if err := r.client.Del(ctx, key).Err(); err != nil { 145 + return fmt.Errorf("failed to delete auth request: %w", err) 146 + } 147 + return nil 148 + }
+1 -1
internal/server/views/partials/header.templ
··· 39 39 class="absolute flex flex-col right-0 mt-2 p-1 gap-1 rounded w-48 bg-bg-light border border-bg-dark" 40 40 > 41 41 <a 42 - href={ templ.URL(fmt.Sprintf("/@%s", params.User.Handle)) } 42 + href={ templ.URL(fmt.Sprintf("/@%s", params.User.BskyProfile.Handle)) } 43 43 class="flex items-center px-4 py-2 text-sm hover:bg-bg gap-2" 44 44 > 45 45 <i class="w-4 h-4" data-lucide="user"></i>
+2 -3
internal/types/types.go
··· 1 1 package types 2 2 3 3 type OauthUser struct { 4 - Handle string 5 - Did string 6 - Pds string 4 + Did string 5 + Pds string 7 6 } 8 7 9 8 type User struct {