Monorepo for Tangled
at 696ec39733bcacc8f77bdb64d12217a23eb19234 361 lines 9.2 kB view raw
1package oauth 2 3import ( 4 "errors" 5 "fmt" 6 "log/slog" 7 "net/http" 8 "time" 9 10 comatproto "github.com/bluesky-social/indigo/api/atproto" 11 "github.com/bluesky-social/indigo/atproto/auth/oauth" 12 atpclient "github.com/bluesky-social/indigo/atproto/client" 13 atcrypto "github.com/bluesky-social/indigo/atproto/crypto" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 xrpc "github.com/bluesky-social/indigo/xrpc" 16 "github.com/gorilla/sessions" 17 "github.com/posthog/posthog-go" 18 "tangled.org/core/appview/config" 19 "tangled.org/core/appview/db" 20 "tangled.org/core/idresolver" 21 "tangled.org/core/rbac" 22) 23 24type OAuth struct { 25 ClientApp *oauth.ClientApp 26 SessStore *sessions.CookieStore 27 Config *config.Config 28 JwksUri string 29 ClientName string 30 ClientUri string 31 Posthog posthog.Client 32 Db *db.DB 33 Enforcer *rbac.Enforcer 34 IdResolver *idresolver.Resolver 35 Logger *slog.Logger 36} 37 38func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 39 var oauthConfig oauth.ClientConfig 40 var clientUri string 41 if config.Core.Dev { 42 clientUri = "http://127.0.0.1:3000" 43 callbackUri := clientUri + "/oauth/callback" 44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes) 45 } else { 46 clientUri = config.Core.AppviewHost 47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 48 callbackUri := clientUri + "/oauth/callback" 49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes) 50 } 51 52 // configure client secret 53 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret) 54 if err != nil { 55 return nil, err 56 } 57 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil { 58 return nil, err 59 } 60 61 jwksUri := clientUri + "/oauth/jwks.json" 62 63 authStore, err := NewRedisStore(&RedisStoreConfig{ 64 RedisURL: config.Redis.ToURL(), 65 SessionExpiryDuration: time.Hour * 24 * 90, 66 SessionInactivityDuration: time.Hour * 24 * 14, 67 AuthRequestExpiryDuration: time.Minute * 30, 68 }) 69 if err != nil { 70 return nil, err 71 } 72 73 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 74 75 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 76 clientApp.Dir = res.Directory() 77 // allow non-public transports in dev mode 78 if config.Core.Dev { 79 clientApp.Resolver.Client.Transport = http.DefaultTransport 80 } 81 82 clientName := config.Core.AppviewName 83 84 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential()) 85 return &OAuth{ 86 ClientApp: clientApp, 87 Config: config, 88 SessStore: sessStore, 89 JwksUri: jwksUri, 90 ClientName: clientName, 91 ClientUri: clientUri, 92 Posthog: ph, 93 Db: db, 94 Enforcer: enforcer, 95 IdResolver: res, 96 Logger: logger, 97 }, nil 98} 99 100func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 101 userSession, err := o.SessStore.Get(r, SessionName) 102 if err != nil { 103 return err 104 } 105 106 userSession.Values[SessionDid] = sessData.AccountDID.String() 107 userSession.Values[SessionPds] = sessData.HostURL 108 userSession.Values[SessionId] = sessData.SessionID 109 userSession.Values[SessionAuthenticated] = true 110 111 if err := userSession.Save(r, w); err != nil { 112 return err 113 } 114 115 handle := "" 116 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String()) 117 if err == nil && resolved.Handle.String() != "" { 118 handle = resolved.Handle.String() 119 } 120 121 registry := o.GetAccounts(r) 122 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil { 123 return err 124 } 125 return o.SaveAccounts(w, r, registry) 126} 127 128func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 129 userSession, err := o.SessStore.Get(r, SessionName) 130 if err != nil { 131 return nil, fmt.Errorf("error getting user session: %w", err) 132 } 133 if userSession.IsNew { 134 return nil, fmt.Errorf("no session available for user") 135 } 136 137 d := userSession.Values[SessionDid].(string) 138 sessDid, err := syntax.ParseDID(d) 139 if err != nil { 140 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 141 } 142 143 sessId := userSession.Values[SessionId].(string) 144 145 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 146 if err != nil { 147 return nil, fmt.Errorf("failed to resume session: %w", err) 148 } 149 150 return clientSess, nil 151} 152 153func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 154 userSession, err := o.SessStore.Get(r, SessionName) 155 if err != nil { 156 return fmt.Errorf("error getting user session: %w", err) 157 } 158 if userSession.IsNew { 159 return fmt.Errorf("no session available for user") 160 } 161 162 d := userSession.Values[SessionDid].(string) 163 sessDid, err := syntax.ParseDID(d) 164 if err != nil { 165 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 166 } 167 168 sessId := userSession.Values[SessionId].(string) 169 170 // delete the session 171 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 172 if err1 != nil { 173 err1 = fmt.Errorf("failed to logout: %w", err1) 174 } 175 176 // remove the cookie 177 userSession.Options.MaxAge = -1 178 err2 := o.SessStore.Save(r, w, userSession) 179 if err2 != nil { 180 err2 = fmt.Errorf("failed to save into session store: %w", err2) 181 } 182 183 return errors.Join(err1, err2) 184} 185 186func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 187 registry := o.GetAccounts(r) 188 account := registry.FindAccount(targetDid) 189 if account == nil { 190 return fmt.Errorf("account not found in registry: %s", targetDid) 191 } 192 193 did, err := syntax.ParseDID(targetDid) 194 if err != nil { 195 return fmt.Errorf("invalid DID: %w", err) 196 } 197 198 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId) 199 if err != nil { 200 registry.RemoveAccount(targetDid) 201 _ = o.SaveAccounts(w, r, registry) 202 return fmt.Errorf("session expired for account: %w", err) 203 } 204 205 userSession, err := o.SessStore.Get(r, SessionName) 206 if err != nil { 207 return err 208 } 209 210 userSession.Values[SessionDid] = sess.Data.AccountDID.String() 211 userSession.Values[SessionPds] = sess.Data.HostURL 212 userSession.Values[SessionId] = sess.Data.SessionID 213 userSession.Values[SessionAuthenticated] = true 214 215 return userSession.Save(r, w) 216} 217 218func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 219 registry := o.GetAccounts(r) 220 account := registry.FindAccount(targetDid) 221 if account == nil { 222 return nil 223 } 224 225 did, err := syntax.ParseDID(targetDid) 226 if err == nil { 227 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId) 228 } 229 230 registry.RemoveAccount(targetDid) 231 return o.SaveAccounts(w, r, registry) 232} 233 234type User struct { 235 Did string 236 Pds string 237} 238 239func (o *OAuth) GetUser(r *http.Request) *User { 240 sess, err := o.ResumeSession(r) 241 if err != nil { 242 return nil 243 } 244 245 return &User{ 246 Did: sess.Data.AccountDID.String(), 247 Pds: sess.Data.HostURL, 248 } 249} 250 251func (o *OAuth) GetDid(r *http.Request) string { 252 if u := o.GetMultiAccountUser(r); u != nil { 253 return u.Did() 254 } 255 256 return "" 257} 258 259func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 260 session, err := o.ResumeSession(r) 261 if err != nil { 262 return nil, fmt.Errorf("error getting session: %w", err) 263 } 264 return session.APIClient(), nil 265} 266 267// this is a higher level abstraction on ServerGetServiceAuth 268type ServiceClientOpts struct { 269 service string 270 exp int64 271 lxm string 272 dev bool 273 timeout time.Duration 274} 275 276type ServiceClientOpt func(*ServiceClientOpts) 277 278func DefaultServiceClientOpts() ServiceClientOpts { 279 return ServiceClientOpts{ 280 timeout: time.Second * 5, 281 } 282} 283 284func WithService(service string) ServiceClientOpt { 285 return func(s *ServiceClientOpts) { 286 s.service = service 287 } 288} 289 290// Specify the Duration in seconds for the expiry of this token 291// 292// The time of expiry is calculated as time.Now().Unix() + exp 293func WithExp(exp int64) ServiceClientOpt { 294 return func(s *ServiceClientOpts) { 295 s.exp = time.Now().Unix() + exp 296 } 297} 298 299func WithLxm(lxm string) ServiceClientOpt { 300 return func(s *ServiceClientOpts) { 301 s.lxm = lxm 302 } 303} 304 305func WithDev(dev bool) ServiceClientOpt { 306 return func(s *ServiceClientOpts) { 307 s.dev = dev 308 } 309} 310 311func WithTimeout(timeout time.Duration) ServiceClientOpt { 312 return func(s *ServiceClientOpts) { 313 s.timeout = timeout 314 } 315} 316 317func (s *ServiceClientOpts) Audience() string { 318 return fmt.Sprintf("did:web:%s", s.service) 319} 320 321func (s *ServiceClientOpts) Host() string { 322 scheme := "https://" 323 if s.dev { 324 scheme = "http://" 325 } 326 327 return scheme + s.service 328} 329 330func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 331 opts := DefaultServiceClientOpts() 332 for _, o := range os { 333 o(&opts) 334 } 335 336 client, err := o.AuthorizedClient(r) 337 if err != nil { 338 return nil, err 339 } 340 341 // force expiry to atleast 60 seconds in the future 342 sixty := time.Now().Unix() + 60 343 if opts.exp < sixty { 344 opts.exp = sixty 345 } 346 347 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 348 if err != nil { 349 return nil, err 350 } 351 352 return &xrpc.Client{ 353 Auth: &xrpc.AuthInfo{ 354 AccessJwt: resp.Token, 355 }, 356 Host: opts.Host(), 357 Client: &http.Client{ 358 Timeout: opts.timeout, 359 }, 360 }, nil 361}