Monorepo for Tangled
1package oauth
2
3import (
4 "errors"
5 "fmt"
6 "log/slog"
7 "net/http"
8 "sync"
9 "time"
10
11 comatproto "github.com/bluesky-social/indigo/api/atproto"
12 "github.com/bluesky-social/indigo/atproto/auth/oauth"
13 atpclient "github.com/bluesky-social/indigo/atproto/client"
14 atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
15 "github.com/bluesky-social/indigo/atproto/syntax"
16 xrpc "github.com/bluesky-social/indigo/xrpc"
17 "github.com/gorilla/sessions"
18 "github.com/posthog/posthog-go"
19 "tangled.org/core/appview/config"
20 "tangled.org/core/appview/db"
21 "tangled.org/core/idresolver"
22 "tangled.org/core/rbac"
23)
24
// OAuth ties together the atproto OAuth client, the browser session cookie
// store and the appview services used by the helpers in this package.
// New builds a fully configured instance.
type OAuth struct {
	// ClientApp resumes, and logs out of, stored OAuth sessions.
	ClientApp *oauth.ClientApp
	// SessStore holds the cookie recording the active account's DID, PDS and
	// OAuth session ID.
	SessStore *sessions.CookieStore
	Config    *config.Config
	// JwksUri is the URL of the client's JWKS document; New sets it to
	// ClientUri + "/oauth/jwks.json".
	JwksUri    string
	ClientName string
	// ClientUri is the appview's base URL as set by New:
	// http://127.0.0.1:3000 in dev mode, https://<AppviewHost> otherwise.
	ClientUri  string
	Posthog    posthog.Client
	Db         *db.DB
	Enforcer   *rbac.Enforcer
	IdResolver *idresolver.Resolver
	Logger     *slog.Logger

	// NOTE(review): not used in this part of the file; presumably a cached
	// app-password session guarded by appPasswordSessionMu — confirm.
	appPasswordSession   *AppPasswordSession
	appPasswordSessionMu sync.Mutex
}
41
42func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
43 var oauthConfig oauth.ClientConfig
44 var clientUri string
45 if config.Core.Dev {
46 clientUri = "http://127.0.0.1:3000"
47 callbackUri := clientUri + "/oauth/callback"
48 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes)
49 } else {
50 clientUri = "https://" + config.Core.AppviewHost
51 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
52 callbackUri := clientUri + "/oauth/callback"
53 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes)
54 }
55
56 // configure client secret
57 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
58 if err != nil {
59 return nil, err
60 }
61 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
62 return nil, err
63 }
64
65 jwksUri := clientUri + "/oauth/jwks.json"
66
67 authStore, err := NewRedisStore(&RedisStoreConfig{
68 RedisURL: config.Redis.ToURL(),
69 SessionExpiryDuration: time.Hour * 24 * 90,
70 SessionInactivityDuration: time.Hour * 24 * 14,
71 AuthRequestExpiryDuration: time.Minute * 30,
72 })
73 if err != nil {
74 return nil, err
75 }
76
77 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
78
79 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
80 clientApp.Dir = res.Directory()
81 // allow non-public transports in dev mode
82 if config.Core.Dev {
83 clientApp.Resolver.Client.Transport = http.DefaultTransport
84 }
85
86 clientName := config.Core.AppviewName
87
88 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential())
89 return &OAuth{
90 ClientApp: clientApp,
91 Config: config,
92 SessStore: sessStore,
93 JwksUri: jwksUri,
94 ClientName: clientName,
95 ClientUri: clientUri,
96 Posthog: ph,
97 Db: db,
98 Enforcer: enforcer,
99 IdResolver: res,
100 Logger: logger,
101 }, nil
102}
103
104func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
105 userSession, err := o.SessStore.Get(r, SessionName)
106 if err != nil {
107 return err
108 }
109
110 userSession.Values[SessionDid] = sessData.AccountDID.String()
111 userSession.Values[SessionPds] = sessData.HostURL
112 userSession.Values[SessionId] = sessData.SessionID
113 userSession.Values[SessionAuthenticated] = true
114
115 if err := userSession.Save(r, w); err != nil {
116 return err
117 }
118
119 handle := ""
120 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String())
121 if err == nil && resolved.Handle.String() != "" {
122 handle = resolved.Handle.String()
123 }
124
125 registry := o.GetAccounts(r)
126 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil {
127 return err
128 }
129 return o.SaveAccounts(w, r, registry)
130}
131
132func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
133 userSession, err := o.SessStore.Get(r, SessionName)
134 if err != nil {
135 return nil, fmt.Errorf("error getting user session: %w", err)
136 }
137 if userSession.IsNew {
138 return nil, fmt.Errorf("no session available for user")
139 }
140
141 d := userSession.Values[SessionDid].(string)
142 sessDid, err := syntax.ParseDID(d)
143 if err != nil {
144 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
145 }
146
147 sessId := userSession.Values[SessionId].(string)
148
149 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
150 if err != nil {
151 return nil, fmt.Errorf("failed to resume session: %w", err)
152 }
153
154 return clientSess, nil
155}
156
157func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
158 userSession, err := o.SessStore.Get(r, SessionName)
159 if err != nil {
160 return fmt.Errorf("error getting user session: %w", err)
161 }
162 if userSession.IsNew {
163 return fmt.Errorf("no session available for user")
164 }
165
166 d := userSession.Values[SessionDid].(string)
167 sessDid, err := syntax.ParseDID(d)
168 if err != nil {
169 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
170 }
171
172 sessId := userSession.Values[SessionId].(string)
173
174 // delete the session
175 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
176 if err1 != nil {
177 err1 = fmt.Errorf("failed to logout: %w", err1)
178 }
179
180 // remove the cookie
181 userSession.Options.MaxAge = -1
182 err2 := o.SessStore.Save(r, w, userSession)
183 if err2 != nil {
184 err2 = fmt.Errorf("failed to save into session store: %w", err2)
185 }
186
187 return errors.Join(err1, err2)
188}
189
190func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
191 registry := o.GetAccounts(r)
192 account := registry.FindAccount(targetDid)
193 if account == nil {
194 return fmt.Errorf("account not found in registry: %s", targetDid)
195 }
196
197 did, err := syntax.ParseDID(targetDid)
198 if err != nil {
199 return fmt.Errorf("invalid DID: %w", err)
200 }
201
202 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId)
203 if err != nil {
204 registry.RemoveAccount(targetDid)
205 _ = o.SaveAccounts(w, r, registry)
206 return fmt.Errorf("session expired for account: %w", err)
207 }
208
209 userSession, err := o.SessStore.Get(r, SessionName)
210 if err != nil {
211 return err
212 }
213
214 userSession.Values[SessionDid] = sess.Data.AccountDID.String()
215 userSession.Values[SessionPds] = sess.Data.HostURL
216 userSession.Values[SessionId] = sess.Data.SessionID
217 userSession.Values[SessionAuthenticated] = true
218
219 return userSession.Save(r, w)
220}
221
222func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
223 registry := o.GetAccounts(r)
224 account := registry.FindAccount(targetDid)
225 if account == nil {
226 return nil
227 }
228
229 did, err := syntax.ParseDID(targetDid)
230 if err == nil {
231 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId)
232 }
233
234 registry.RemoveAccount(targetDid)
235 return o.SaveAccounts(w, r, registry)
236}
237
// User identifies the account behind a resumed OAuth session.
type User struct {
	// Did is the account DID in string form.
	Did string
	// Pds is the host URL recorded in the OAuth session data.
	Pds string
}
242
243func (o *OAuth) GetUser(r *http.Request) *User {
244 sess, err := o.ResumeSession(r)
245 if err != nil {
246 return nil
247 }
248
249 return &User{
250 Did: sess.Data.AccountDID.String(),
251 Pds: sess.Data.HostURL,
252 }
253}
254
255func (o *OAuth) GetDid(r *http.Request) string {
256 if u := o.GetMultiAccountUser(r); u != nil {
257 return u.Did()
258 }
259
260 return ""
261}
262
263func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
264 session, err := o.ResumeSession(r)
265 if err != nil {
266 return nil, fmt.Errorf("error getting session: %w", err)
267 }
268 return session.APIClient(), nil
269}
270
// ServiceClientOpts configures ServiceClient. It selects the target service
// and the parameters of the service-auth token requested for it, as a
// higher-level abstraction on ServerGetServiceAuth.
type ServiceClientOpts struct {
	service string        // service host; forms both the did:web audience and the base URL
	exp     int64         // absolute token expiry in Unix seconds (see WithExp)
	lxm     string        // passed as lxm to ServerGetServiceAuth
	dev     bool          // when true, Host uses http:// instead of https://
	timeout time.Duration // timeout of the returned HTTP client
}

// ServiceClientOpt mutates a ServiceClientOpts; see the With* constructors.
type ServiceClientOpt func(*ServiceClientOpts)

// DefaultServiceClientOpts returns the starting options used by ServiceClient:
// a 5 second timeout, with every other field left at its zero value.
func DefaultServiceClientOpts() ServiceClientOpts {
	var defaults ServiceClientOpts
	defaults.timeout = 5 * time.Second
	return defaults
}

// WithService sets the host of the target service.
func WithService(service string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.service = service }
}

// WithExp sets the token lifetime to exp seconds.
//
// The absolute expiry, time.Now().Unix() + exp, is computed when the option is
// applied, not when WithExp is called.
func WithExp(exp int64) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		now := time.Now().Unix()
		opts.exp = now + exp
	}
}

// WithLxm sets the lxm value sent with the service-auth request.
func WithLxm(lxm string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.lxm = lxm }
}

// WithDev selects plain http:// (true) or https:// (false) for Host.
func WithDev(dev bool) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.dev = dev }
}

// WithTimeout sets the timeout of the HTTP client built by ServiceClient.
func WithTimeout(timeout time.Duration) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.timeout = timeout }
}

// Audience returns "did:web:" followed by the configured service host.
func (s *ServiceClientOpts) Audience() string {
	return "did:web:" + s.service
}

// Host returns the base URL of the configured service: http:// in dev mode,
// https:// otherwise.
func (s *ServiceClientOpts) Host() string {
	if s.dev {
		return "http://" + s.service
	}
	return "https://" + s.service
}
333
334func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
335 opts := DefaultServiceClientOpts()
336 for _, o := range os {
337 o(&opts)
338 }
339
340 client, err := o.AuthorizedClient(r)
341 if err != nil {
342 return nil, err
343 }
344
345 // force expiry to atleast 60 seconds in the future
346 sixty := time.Now().Unix() + 60
347 if opts.exp < sixty {
348 opts.exp = sixty
349 }
350
351 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
352 if err != nil {
353 return nil, err
354 }
355
356 return &xrpc.Client{
357 Auth: &xrpc.AuthInfo{
358 AccessJwt: resp.Token,
359 },
360 Host: opts.Host(),
361 Client: &http.Client{
362 Timeout: opts.timeout,
363 },
364 }, nil
365}