this repo has no description
1package oauth 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "time" 8 9 "github.com/gorilla/sessions" 10 oauth "github.com/haileyok/atproto-oauth-golang" 11 "github.com/haileyok/atproto-oauth-golang/helpers" 12 "tangled.sh/tangled.sh/core/appview" 13 "tangled.sh/tangled.sh/core/appview/db" 14 "tangled.sh/tangled.sh/core/appview/oauth/client" 15 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient" 16) 17 18type OAuthRequest struct { 19 ID uint 20 AuthserverIss string 21 State string 22 Did string 23 PdsUrl string 24 PkceVerifier string 25 DpopAuthserverNonce string 26 DpopPrivateJwk string 27} 28 29type OAuth struct { 30 Store *sessions.CookieStore 31 Db *db.DB 32 Config *appview.Config 33} 34 35func NewOAuth(db *db.DB, config *appview.Config) *OAuth { 36 return &OAuth{ 37 Store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), 38 Db: db, 39 Config: config, 40 } 41} 42 43func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error { 44 // first we save the did in the user session 45 userSession, err := o.Store.Get(r, appview.SessionName) 46 if err != nil { 47 return err 48 } 49 50 userSession.Values[appview.SessionDid] = oreq.Did 51 userSession.Values[appview.SessionAuthenticated] = true 52 err = userSession.Save(r, w) 53 if err != nil { 54 return fmt.Errorf("error saving user session: %w", err) 55 } 56 57 // then save the whole thing in the db 58 session := db.OAuthSession{ 59 Did: oreq.Did, 60 Handle: oreq.Handle, 61 PdsUrl: oreq.PdsUrl, 62 DpopAuthserverNonce: oreq.DpopAuthserverNonce, 63 AuthServerIss: oreq.AuthserverIss, 64 DpopPrivateJwk: oreq.DpopPrivateJwk, 65 AccessJwt: oresp.AccessToken, 66 RefreshJwt: oresp.RefreshToken, 67 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339), 68 } 69 70 return db.SaveOAuthSession(o.Db, session) 71} 72 73func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error { 74 userSession, err := o.Store.Get(r, appview.SessionName) 75 if err != nil || userSession.IsNew { 76 return fmt.Errorf("error getting user session (or new session?): %w", err) 77 } 78 79 did := userSession.Values[appview.SessionDid].(string) 80 81 err = db.DeleteOAuthSessionByDid(o.Db, did) 82 if err != nil { 83 return fmt.Errorf("error deleting oauth session: %w", err) 84 } 85 86 userSession.Options.MaxAge = -1 87 88 return userSession.Save(r, w) 89} 90 91func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) { 92 userSession, err := o.Store.Get(r, appview.SessionName) 93 if err != nil || userSession.IsNew { 94 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err) 95 } 96 97 did := userSession.Values[appview.SessionDid].(string) 98 auth := userSession.Values[appview.SessionAuthenticated].(bool) 99 100 session, err := db.GetOAuthSessionByDid(o.Db, did) 101 if err != nil { 102 return nil, false, fmt.Errorf("error getting oauth session: %w", err) 103 } 104 105 expiry, err := time.Parse(time.RFC3339, session.Expiry) 106 if err != nil { 107 return nil, false, fmt.Errorf("error parsing expiry time: %w", err) 108 } 109 if expiry.Sub(time.Now()) <= 5*time.Minute { 110 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 111 if err != nil { 112 return nil, false, err 113 } 114 oauthClient, err := client.NewClient(o.Config.OAuth.ServerMetadataUrl, 115 o.Config.OAuth.Jwks, 116 fmt.Sprintf("%s/oauth/callback", o.Config.Core.AppviewHost)) 117 118 if err != nil { 119 return nil, false, err 120 } 121 122 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk) 123 if err != nil { 124 return nil, false, err 125 } 126 127 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339) 128 err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry) 129 if err != nil { 130 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err) 131 } 132 133 // update the current session 134 session.AccessJwt = resp.AccessToken 135 session.RefreshJwt = resp.RefreshToken 136 session.DpopAuthserverNonce = resp.DpopAuthserverNonce 137 session.Expiry = newExpiry 138 } 139 140 return session, auth, nil 141} 142 143type User struct { 144 Handle string 145 Did string 146 Pds string 147} 148 149func (a *OAuth) GetUser(r *http.Request) *User { 150 clientSession, err := a.Store.Get(r, appview.SessionName) 151 152 if err != nil || clientSession.IsNew { 153 return nil 154 } 155 156 return &User{ 157 Handle: clientSession.Values[appview.SessionHandle].(string), 158 Did: clientSession.Values[appview.SessionDid].(string), 159 Pds: clientSession.Values[appview.SessionPds].(string), 160 } 161} 162 163func (a *OAuth) GetDid(r *http.Request) string { 164 clientSession, err := a.Store.Get(r, appview.SessionName) 165 166 if err != nil || clientSession.IsNew { 167 return "" 168 } 169 170 return clientSession.Values[appview.SessionDid].(string) 171} 172 173func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) { 174 session, auth, err := o.GetSession(r) 175 if err != nil { 176 return nil, fmt.Errorf("error getting session: %w", err) 177 } 178 if !auth { 179 return nil, fmt.Errorf("not authorized") 180 } 181 182 client := &oauth.XrpcClient{ 183 OnDpopPdsNonceChanged: func(did, newNonce string) { 184 err := db.UpdateDpopPdsNonce(o.Db, did, newNonce) 185 if err != nil { 186 log.Printf("error updating dpop pds nonce: %v", err) 187 } 188 }, 189 } 190 191 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 192 if err != nil { 193 return nil, fmt.Errorf("error parsing private jwk: %w", err) 194 } 195 196 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{ 197 Did: session.Did, 198 PdsUrl: session.PdsUrl, 199 DpopPdsNonce: session.PdsUrl, 200 AccessToken: session.AccessJwt, 201 Issuer: session.AuthServerIss, 202 DpopPrivateJwk: privateJwk, 203 }) 204 205 return xrpcClient, nil 206}