Monorepo for Tangled
at 88a45c13b8c4e4ae11608363ae74886874f18be2 255 lines 6.1 kB view raw
1package oauth 2 3import ( 4 "errors" 5 "fmt" 6 "net/http" 7 "time" 8 9 comatproto "github.com/bluesky-social/indigo/api/atproto" 10 "github.com/bluesky-social/indigo/atproto/auth/oauth" 11 atpclient "github.com/bluesky-social/indigo/atproto/client" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 xrpc "github.com/bluesky-social/indigo/xrpc" 14 "github.com/gorilla/sessions" 15 "github.com/lestrrat-go/jwx/v2/jwk" 16 "github.com/posthog/posthog-go" 17 "tangled.org/core/appview/config" 18) 19 20func New(config *config.Config, ph posthog.Client) (*OAuth, error) { 21 22 var oauthConfig oauth.ClientConfig 23 var clientUri string 24 25 if config.Core.Dev { 26 clientUri = "http://127.0.0.1:3000" 27 callbackUri := clientUri + "/oauth/callback" 28 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"}) 29 } else { 30 clientUri = config.Core.AppviewHost 31 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 32 callbackUri := clientUri + "/oauth/callback" 33 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"}) 34 } 35 36 jwksUri := clientUri + "/oauth/jwks.json" 37 38 authStore, err := NewRedisStore(config.Redis.ToURL()) 39 if err != nil { 40 return nil, err 41 } 42 43 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 44 45 return &OAuth{ 46 ClientApp: oauth.NewClientApp(&oauthConfig, authStore), 47 Config: config, 48 SessStore: sessStore, 49 JwksUri: jwksUri, 50 Posthog: ph, 51 }, nil 52} 53 54type OAuth struct { 55 ClientApp *oauth.ClientApp 56 SessStore *sessions.CookieStore 57 Config *config.Config 58 JwksUri string 59 Posthog posthog.Client 60} 61 62func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 63 // first we save the did in the user session 64 userSession, err := o.SessStore.Get(r, SessionName) 65 if err != nil { 66 return err 67 } 68 69 userSession.Values[SessionDid] = sessData.AccountDID.String() 70 userSession.Values[SessionPds] = sessData.HostURL 71 userSession.Values[SessionId] = sessData.SessionID 72 userSession.Values[SessionAuthenticated] = true 73 return userSession.Save(r, w) 74} 75 76func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 77 userSession, err := o.SessStore.Get(r, SessionName) 78 if err != nil { 79 return nil, fmt.Errorf("error getting user session: %w", err) 80 } 81 if userSession.IsNew { 82 return nil, fmt.Errorf("no session available for user") 83 } 84 85 d := userSession.Values[SessionDid].(string) 86 sessDid, err := syntax.ParseDID(d) 87 if err != nil { 88 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 89 } 90 91 sessId := userSession.Values[SessionId].(string) 92 93 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 94 if err != nil { 95 return nil, fmt.Errorf("failed to resume session: %w", err) 96 } 97 98 return clientSess, nil 99} 100 101func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 102 userSession, err := o.SessStore.Get(r, SessionName) 103 if err != nil { 104 return fmt.Errorf("error getting user session: %w", err) 105 } 106 if userSession.IsNew { 107 return fmt.Errorf("no session available for user") 108 } 109 110 d := userSession.Values[SessionDid].(string) 111 sessDid, err := syntax.ParseDID(d) 112 if err != nil { 113 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 114 } 115 116 sessId := userSession.Values[SessionId].(string) 117 118 // delete the session 119 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 120 121 // remove the cookie 122 userSession.Options.MaxAge = -1 123 err2 := o.SessStore.Save(r, w, userSession) 124 125 return errors.Join(err1, err2) 126} 127 128func pubKeyFromJwk(jwks string) (jwk.Key, error) { 129 k, err := jwk.ParseKey([]byte(jwks)) 130 if err != nil { 131 return nil, err 132 } 133 pubKey, err := k.PublicKey() 134 if err != nil { 135 return nil, err 136 } 137 return pubKey, nil 138} 139 140type User struct { 141 Did string 142 Pds string 143} 144 145func (o *OAuth) GetUser(r *http.Request) *User { 146 sess, err := o.SessStore.Get(r, SessionName) 147 148 if err != nil || sess.IsNew { 149 return nil 150 } 151 152 return &User{ 153 Did: sess.Values[SessionDid].(string), 154 Pds: sess.Values[SessionPds].(string), 155 } 156} 157 158func (o *OAuth) GetDid(r *http.Request) string { 159 if u := o.GetUser(r); u != nil { 160 return u.Did 161 } 162 163 return "" 164} 165 166func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 167 session, err := o.ResumeSession(r) 168 if err != nil { 169 return nil, fmt.Errorf("error getting session: %w", err) 170 } 171 return session.APIClient(), nil 172} 173 174// this is a higher level abstraction on ServerGetServiceAuth 175type ServiceClientOpts struct { 176 service string 177 exp int64 178 lxm string 179 dev bool 180} 181 182type ServiceClientOpt func(*ServiceClientOpts) 183 184func WithService(service string) ServiceClientOpt { 185 return func(s *ServiceClientOpts) { 186 s.service = service 187 } 188} 189 190// Specify the Duration in seconds for the expiry of this token 191// 192// The time of expiry is calculated as time.Now().Unix() + exp 193func WithExp(exp int64) ServiceClientOpt { 194 return func(s *ServiceClientOpts) { 195 s.exp = time.Now().Unix() + exp 196 } 197} 198 199func WithLxm(lxm string) ServiceClientOpt { 200 return func(s *ServiceClientOpts) { 201 s.lxm = lxm 202 } 203} 204 205func WithDev(dev bool) ServiceClientOpt { 206 return func(s *ServiceClientOpts) { 207 s.dev = dev 208 } 209} 210 211func (s *ServiceClientOpts) Audience() string { 212 return fmt.Sprintf("did:web:%s", s.service) 213} 214 215func (s *ServiceClientOpts) Host() string { 216 scheme := "https://" 217 if s.dev { 218 scheme = "http://" 219 } 220 221 return scheme + s.service 222} 223 224func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 225 opts := ServiceClientOpts{} 226 for _, o := range os { 227 o(&opts) 228 } 229 230 client, err := o.AuthorizedClient(r) 231 if err != nil { 232 return nil, err 233 } 234 235 // force expiry to atleast 60 seconds in the future 236 sixty := time.Now().Unix() + 60 237 if opts.exp < sixty { 238 opts.exp = sixty 239 } 240 241 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 242 if err != nil { 243 return nil, err 244 } 245 246 return &xrpc.Client{ 247 Auth: &xrpc.AuthInfo{ 248 AccessJwt: resp.Token, 249 }, 250 Host: opts.Host(), 251 Client: &http.Client{ 252 Timeout: time.Second * 5, 253 }, 254 }, nil 255}