Monorepo for Tangled
At commit f887743001ec3f0f95f90c1a4d807cb37d9132d7 · 365 lines · 9.3 kB · view raw
1package oauth 2 3import ( 4 "errors" 5 "fmt" 6 "log/slog" 7 "net/http" 8 "sync" 9 "time" 10 11 comatproto "github.com/bluesky-social/indigo/api/atproto" 12 "github.com/bluesky-social/indigo/atproto/auth/oauth" 13 atpclient "github.com/bluesky-social/indigo/atproto/client" 14 atcrypto "github.com/bluesky-social/indigo/atproto/crypto" 15 "github.com/bluesky-social/indigo/atproto/syntax" 16 xrpc "github.com/bluesky-social/indigo/xrpc" 17 "github.com/gorilla/sessions" 18 "github.com/posthog/posthog-go" 19 "tangled.org/core/appview/config" 20 "tangled.org/core/appview/db" 21 "tangled.org/core/idresolver" 22 "tangled.org/core/rbac" 23) 24 25type OAuth struct { 26 ClientApp *oauth.ClientApp 27 SessStore *sessions.CookieStore 28 Config *config.Config 29 JwksUri string 30 ClientName string 31 ClientUri string 32 Posthog posthog.Client 33 Db *db.DB 34 Enforcer *rbac.Enforcer 35 IdResolver *idresolver.Resolver 36 Logger *slog.Logger 37 38 appPasswordSession *AppPasswordSession 39 appPasswordSessionMu sync.Mutex 40} 41 42func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 43 var oauthConfig oauth.ClientConfig 44 var clientUri string 45 if config.Core.Dev { 46 clientUri = "http://127.0.0.1:3000" 47 callbackUri := clientUri + "/oauth/callback" 48 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes) 49 } else { 50 clientUri = "https://" + config.Core.AppviewHost 51 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 52 callbackUri := clientUri + "/oauth/callback" 53 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes) 54 } 55 56 // configure client secret 57 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret) 58 if err != nil { 59 return nil, err 60 } 61 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil { 62 return nil, err 63 } 64 65 jwksUri := clientUri + "/oauth/jwks.json" 66 67 
authStore, err := NewRedisStore(&RedisStoreConfig{ 68 RedisURL: config.Redis.ToURL(), 69 SessionExpiryDuration: time.Hour * 24 * 90, 70 SessionInactivityDuration: time.Hour * 24 * 14, 71 AuthRequestExpiryDuration: time.Minute * 30, 72 }) 73 if err != nil { 74 return nil, err 75 } 76 77 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 78 79 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 80 clientApp.Dir = res.Directory() 81 // allow non-public transports in dev mode 82 if config.Core.Dev { 83 clientApp.Resolver.Client.Transport = http.DefaultTransport 84 } 85 86 clientName := config.Core.AppviewName 87 88 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential()) 89 return &OAuth{ 90 ClientApp: clientApp, 91 Config: config, 92 SessStore: sessStore, 93 JwksUri: jwksUri, 94 ClientName: clientName, 95 ClientUri: clientUri, 96 Posthog: ph, 97 Db: db, 98 Enforcer: enforcer, 99 IdResolver: res, 100 Logger: logger, 101 }, nil 102} 103 104func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 105 userSession, err := o.SessStore.Get(r, SessionName) 106 if err != nil { 107 return err 108 } 109 110 userSession.Values[SessionDid] = sessData.AccountDID.String() 111 userSession.Values[SessionPds] = sessData.HostURL 112 userSession.Values[SessionId] = sessData.SessionID 113 userSession.Values[SessionAuthenticated] = true 114 115 if err := userSession.Save(r, w); err != nil { 116 return err 117 } 118 119 handle := "" 120 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String()) 121 if err == nil && resolved.Handle.String() != "" { 122 handle = resolved.Handle.String() 123 } 124 125 registry := o.GetAccounts(r) 126 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil { 127 return err 128 } 129 return o.SaveAccounts(w, r, registry) 130} 131 132func (o *OAuth) ResumeSession(r *http.Request) 
(*oauth.ClientSession, error) { 133 userSession, err := o.SessStore.Get(r, SessionName) 134 if err != nil { 135 return nil, fmt.Errorf("error getting user session: %w", err) 136 } 137 if userSession.IsNew { 138 return nil, fmt.Errorf("no session available for user") 139 } 140 141 d := userSession.Values[SessionDid].(string) 142 sessDid, err := syntax.ParseDID(d) 143 if err != nil { 144 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 145 } 146 147 sessId := userSession.Values[SessionId].(string) 148 149 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 150 if err != nil { 151 return nil, fmt.Errorf("failed to resume session: %w", err) 152 } 153 154 return clientSess, nil 155} 156 157func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 158 userSession, err := o.SessStore.Get(r, SessionName) 159 if err != nil { 160 return fmt.Errorf("error getting user session: %w", err) 161 } 162 if userSession.IsNew { 163 return fmt.Errorf("no session available for user") 164 } 165 166 d := userSession.Values[SessionDid].(string) 167 sessDid, err := syntax.ParseDID(d) 168 if err != nil { 169 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 170 } 171 172 sessId := userSession.Values[SessionId].(string) 173 174 // delete the session 175 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 176 if err1 != nil { 177 err1 = fmt.Errorf("failed to logout: %w", err1) 178 } 179 180 // remove the cookie 181 userSession.Options.MaxAge = -1 182 err2 := o.SessStore.Save(r, w, userSession) 183 if err2 != nil { 184 err2 = fmt.Errorf("failed to save into session store: %w", err2) 185 } 186 187 return errors.Join(err1, err2) 188} 189 190func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 191 registry := o.GetAccounts(r) 192 account := registry.FindAccount(targetDid) 193 if account == nil { 194 return fmt.Errorf("account not found in registry: %s", 
targetDid) 195 } 196 197 did, err := syntax.ParseDID(targetDid) 198 if err != nil { 199 return fmt.Errorf("invalid DID: %w", err) 200 } 201 202 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId) 203 if err != nil { 204 registry.RemoveAccount(targetDid) 205 _ = o.SaveAccounts(w, r, registry) 206 return fmt.Errorf("session expired for account: %w", err) 207 } 208 209 userSession, err := o.SessStore.Get(r, SessionName) 210 if err != nil { 211 return err 212 } 213 214 userSession.Values[SessionDid] = sess.Data.AccountDID.String() 215 userSession.Values[SessionPds] = sess.Data.HostURL 216 userSession.Values[SessionId] = sess.Data.SessionID 217 userSession.Values[SessionAuthenticated] = true 218 219 return userSession.Save(r, w) 220} 221 222func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 223 registry := o.GetAccounts(r) 224 account := registry.FindAccount(targetDid) 225 if account == nil { 226 return nil 227 } 228 229 did, err := syntax.ParseDID(targetDid) 230 if err == nil { 231 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId) 232 } 233 234 registry.RemoveAccount(targetDid) 235 return o.SaveAccounts(w, r, registry) 236} 237 238type User struct { 239 Did string 240 Pds string 241} 242 243func (o *OAuth) GetUser(r *http.Request) *User { 244 sess, err := o.ResumeSession(r) 245 if err != nil { 246 return nil 247 } 248 249 return &User{ 250 Did: sess.Data.AccountDID.String(), 251 Pds: sess.Data.HostURL, 252 } 253} 254 255func (o *OAuth) GetDid(r *http.Request) string { 256 if u := o.GetMultiAccountUser(r); u != nil { 257 return u.Did() 258 } 259 260 return "" 261} 262 263func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 264 session, err := o.ResumeSession(r) 265 if err != nil { 266 return nil, fmt.Errorf("error getting session: %w", err) 267 } 268 return session.APIClient(), nil 269} 270 271// this is a higher level abstraction on ServerGetServiceAuth 272type 
ServiceClientOpts struct { 273 service string 274 exp int64 275 lxm string 276 dev bool 277 timeout time.Duration 278} 279 280type ServiceClientOpt func(*ServiceClientOpts) 281 282func DefaultServiceClientOpts() ServiceClientOpts { 283 return ServiceClientOpts{ 284 timeout: time.Second * 5, 285 } 286} 287 288func WithService(service string) ServiceClientOpt { 289 return func(s *ServiceClientOpts) { 290 s.service = service 291 } 292} 293 294// Specify the Duration in seconds for the expiry of this token 295// 296// The time of expiry is calculated as time.Now().Unix() + exp 297func WithExp(exp int64) ServiceClientOpt { 298 return func(s *ServiceClientOpts) { 299 s.exp = time.Now().Unix() + exp 300 } 301} 302 303func WithLxm(lxm string) ServiceClientOpt { 304 return func(s *ServiceClientOpts) { 305 s.lxm = lxm 306 } 307} 308 309func WithDev(dev bool) ServiceClientOpt { 310 return func(s *ServiceClientOpts) { 311 s.dev = dev 312 } 313} 314 315func WithTimeout(timeout time.Duration) ServiceClientOpt { 316 return func(s *ServiceClientOpts) { 317 s.timeout = timeout 318 } 319} 320 321func (s *ServiceClientOpts) Audience() string { 322 return fmt.Sprintf("did:web:%s", s.service) 323} 324 325func (s *ServiceClientOpts) Host() string { 326 scheme := "https://" 327 if s.dev { 328 scheme = "http://" 329 } 330 331 return scheme + s.service 332} 333 334func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 335 opts := DefaultServiceClientOpts() 336 for _, o := range os { 337 o(&opts) 338 } 339 340 client, err := o.AuthorizedClient(r) 341 if err != nil { 342 return nil, err 343 } 344 345 // force expiry to atleast 60 seconds in the future 346 sixty := time.Now().Unix() + 60 347 if opts.exp < sixty { 348 opts.exp = sixty 349 } 350 351 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 352 if err != nil { 353 return nil, err 354 } 355 356 return &xrpc.Client{ 357 Auth: &xrpc.AuthInfo{ 358 
AccessJwt: resp.Token, 359 }, 360 Host: opts.Host(), 361 Client: &http.Client{ 362 Timeout: opts.timeout, 363 }, 364 }, nil 365}