this repo has no description
1package oauth 2 3import ( 4 "errors" 5 "fmt" 6 "log/slog" 7 "net/http" 8 "time" 9 10 comatproto "github.com/bluesky-social/indigo/api/atproto" 11 "github.com/bluesky-social/indigo/atproto/auth/oauth" 12 atpclient "github.com/bluesky-social/indigo/atproto/client" 13 atcrypto "github.com/bluesky-social/indigo/atproto/crypto" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 xrpc "github.com/bluesky-social/indigo/xrpc" 16 "github.com/gorilla/sessions" 17 "github.com/posthog/posthog-go" 18 "tangled.org/core/appview/config" 19 "tangled.org/core/appview/db" 20 "tangled.org/core/idresolver" 21 "tangled.org/core/rbac" 22) 23 24type OAuth struct { 25 ClientApp *oauth.ClientApp 26 SessStore *sessions.CookieStore 27 Config *config.Config 28 JwksUri string 29 ClientName string 30 ClientUri string 31 Posthog posthog.Client 32 Db *db.DB 33 Enforcer *rbac.Enforcer 34 IdResolver *idresolver.Resolver 35 Logger *slog.Logger 36} 37 38func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 39 var oauthConfig oauth.ClientConfig 40 var clientUri string 41 if config.Core.Dev { 42 clientUri = "http://127.0.0.1:3000" 43 callbackUri := clientUri + "/oauth/callback" 44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"}) 45 } else { 46 clientUri = config.Core.AppviewHost 47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 48 callbackUri := clientUri + "/oauth/callback" 49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"}) 50 } 51 52 // configure client secret 53 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret) 54 if err != nil { 55 return nil, err 56 } 57 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil { 58 return nil, err 59 } 60 61 jwksUri := clientUri + "/oauth/jwks.json" 62 63 authStore, err := NewRedisStore(config.Redis.ToURL()) 64 if err != nil { 65 return nil, err 66 } 67 68 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 69 70 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 71 clientApp.Dir = res.Directory() 72 73 clientName := config.Core.AppviewName 74 75 return &OAuth{ 76 ClientApp: clientApp, 77 Config: config, 78 SessStore: sessStore, 79 JwksUri: jwksUri, 80 ClientName: clientName, 81 ClientUri: clientUri, 82 Posthog: ph, 83 Db: db, 84 Enforcer: enforcer, 85 IdResolver: res, 86 Logger: logger, 87 }, nil 88} 89 90func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 91 // first we save the did in the user session 92 userSession, err := o.SessStore.Get(r, SessionName) 93 if err != nil { 94 return err 95 } 96 97 userSession.Values[SessionDid] = sessData.AccountDID.String() 98 userSession.Values[SessionPds] = sessData.HostURL 99 userSession.Values[SessionId] = sessData.SessionID 100 userSession.Values[SessionAuthenticated] = true 101 return userSession.Save(r, w) 102} 103 104func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 105 userSession, err := o.SessStore.Get(r, SessionName) 106 if err != nil { 107 return nil, fmt.Errorf("error getting user session: %w", err) 108 } 109 if userSession.IsNew { 110 return nil, fmt.Errorf("no session available for user") 111 } 112 113 d := userSession.Values[SessionDid].(string) 114 sessDid, err := syntax.ParseDID(d) 115 if err != nil { 116 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 117 } 118 119 sessId := userSession.Values[SessionId].(string) 120 121 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 122 if err != nil { 123 return nil, fmt.Errorf("failed to resume session: %w", err) 124 } 125 126 return clientSess, nil 127} 128 129func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 130 userSession, err := o.SessStore.Get(r, SessionName) 131 if err != nil { 132 return fmt.Errorf("error getting user session: %w", err) 133 } 134 if userSession.IsNew { 135 return fmt.Errorf("no session available for user") 136 } 137 138 d := userSession.Values[SessionDid].(string) 139 sessDid, err := syntax.ParseDID(d) 140 if err != nil { 141 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 142 } 143 144 sessId := userSession.Values[SessionId].(string) 145 146 // delete the session 147 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 148 149 // remove the cookie 150 userSession.Options.MaxAge = -1 151 err2 := o.SessStore.Save(r, w, userSession) 152 153 return errors.Join(err1, err2) 154} 155 156type User struct { 157 Did string 158 Pds string 159} 160 161func (o *OAuth) GetUser(r *http.Request) *User { 162 sess, err := o.SessStore.Get(r, SessionName) 163 164 if err != nil || sess.IsNew { 165 return nil 166 } 167 168 return &User{ 169 Did: sess.Values[SessionDid].(string), 170 Pds: sess.Values[SessionPds].(string), 171 } 172} 173 174func (o *OAuth) GetDid(r *http.Request) string { 175 if u := o.GetUser(r); u != nil { 176 return u.Did 177 } 178 179 return "" 180} 181 182func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 183 session, err := o.ResumeSession(r) 184 if err != nil { 185 return nil, fmt.Errorf("error getting session: %w", err) 186 } 187 return session.APIClient(), nil 188} 189 190// this is a higher level abstraction on ServerGetServiceAuth 191type ServiceClientOpts struct { 192 service string 193 exp int64 194 lxm string 195 dev bool 196} 197 198type ServiceClientOpt func(*ServiceClientOpts) 199 200func WithService(service string) ServiceClientOpt { 201 return func(s *ServiceClientOpts) { 202 s.service = service 203 } 204} 205 206// Specify the Duration in seconds for the expiry of this token 207// 208// The time of expiry is calculated as time.Now().Unix() + exp 209func WithExp(exp int64) ServiceClientOpt { 210 return func(s *ServiceClientOpts) { 211 s.exp = time.Now().Unix() + exp 212 } 213} 214 215func WithLxm(lxm string) ServiceClientOpt { 216 return func(s *ServiceClientOpts) { 217 s.lxm = lxm 218 } 219} 220 221func WithDev(dev bool) ServiceClientOpt { 222 return func(s *ServiceClientOpts) { 223 s.dev = dev 224 } 225} 226 227func (s *ServiceClientOpts) Audience() string { 228 return fmt.Sprintf("did:web:%s", s.service) 229} 230 231func (s *ServiceClientOpts) Host() string { 232 scheme := "https://" 233 if s.dev { 234 scheme = "http://" 235 } 236 237 return scheme + s.service 238} 239 240func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 241 opts := ServiceClientOpts{} 242 for _, o := range os { 243 o(&opts) 244 } 245 246 client, err := o.AuthorizedClient(r) 247 if err != nil { 248 return nil, err 249 } 250 251 // force expiry to atleast 60 seconds in the future 252 sixty := time.Now().Unix() + 60 253 if opts.exp < sixty { 254 opts.exp = sixty 255 } 256 257 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 258 if err != nil { 259 return nil, err 260 } 261 262 return &xrpc.Client{ 263 Auth: &xrpc.AuthInfo{ 264 AccessJwt: resp.Token, 265 }, 266 Host: opts.Host(), 267 Client: &http.Client{ 268 Timeout: time.Second * 5, 269 }, 270 }, nil 271}