Monorepo for Tangled
1package oauth
2
3import (
4 "errors"
5 "fmt"
6 "log/slog"
7 "net/http"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/auth/oauth"
12 atpclient "github.com/bluesky-social/indigo/atproto/client"
13 atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 xrpc "github.com/bluesky-social/indigo/xrpc"
16 "github.com/gorilla/sessions"
17 "github.com/posthog/posthog-go"
18 "tangled.org/core/appview/config"
19 "tangled.org/core/appview/db"
20 "tangled.org/core/idresolver"
21 "tangled.org/core/rbac"
22)
23
24type OAuth struct {
25 ClientApp *oauth.ClientApp
26 SessStore *sessions.CookieStore
27 Config *config.Config
28 JwksUri string
29 ClientName string
30 ClientUri string
31 Posthog posthog.Client
32 Db *db.DB
33 Enforcer *rbac.Enforcer
34 IdResolver *idresolver.Resolver
35 Logger *slog.Logger
36}
37
38func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
39 var oauthConfig oauth.ClientConfig
40 var clientUri string
41 if config.Core.Dev {
42 clientUri = "http://127.0.0.1:3000"
43 callbackUri := clientUri + "/oauth/callback"
44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes)
45 } else {
46 clientUri = "https://" + config.Core.AppviewHost
47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
48 callbackUri := clientUri + "/oauth/callback"
49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes)
50 }
51
52 // configure client secret
53 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
54 if err != nil {
55 return nil, err
56 }
57 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
58 return nil, err
59 }
60
61 jwksUri := clientUri + "/oauth/jwks.json"
62
63 authStore, err := NewRedisStore(&RedisStoreConfig{
64 RedisURL: config.Redis.ToURL(),
65 SessionExpiryDuration: time.Hour * 24 * 90,
66 SessionInactivityDuration: time.Hour * 24 * 14,
67 AuthRequestExpiryDuration: time.Minute * 30,
68 })
69 if err != nil {
70 return nil, err
71 }
72
73 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
74
75 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
76 clientApp.Dir = res.Directory()
77 // allow non-public transports in dev mode
78 if config.Core.Dev {
79 clientApp.Resolver.Client.Transport = http.DefaultTransport
80 }
81
82 clientName := config.Core.AppviewName
83
84 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential())
85 return &OAuth{
86 ClientApp: clientApp,
87 Config: config,
88 SessStore: sessStore,
89 JwksUri: jwksUri,
90 ClientName: clientName,
91 ClientUri: clientUri,
92 Posthog: ph,
93 Db: db,
94 Enforcer: enforcer,
95 IdResolver: res,
96 Logger: logger,
97 }, nil
98}
99
100func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
101 userSession, err := o.SessStore.Get(r, SessionName)
102 if err != nil {
103 return err
104 }
105
106 userSession.Values[SessionDid] = sessData.AccountDID.String()
107 userSession.Values[SessionPds] = sessData.HostURL
108 userSession.Values[SessionId] = sessData.SessionID
109 userSession.Values[SessionAuthenticated] = true
110
111 if err := userSession.Save(r, w); err != nil {
112 return err
113 }
114
115 handle := ""
116 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String())
117 if err == nil && resolved.Handle.String() != "" {
118 handle = resolved.Handle.String()
119 }
120
121 registry := o.GetAccounts(r)
122 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil {
123 return err
124 }
125 return o.SaveAccounts(w, r, registry)
126}
127
128func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
129 userSession, err := o.SessStore.Get(r, SessionName)
130 if err != nil {
131 return nil, fmt.Errorf("error getting user session: %w", err)
132 }
133 if userSession.IsNew {
134 return nil, fmt.Errorf("no session available for user")
135 }
136
137 d := userSession.Values[SessionDid].(string)
138 sessDid, err := syntax.ParseDID(d)
139 if err != nil {
140 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
141 }
142
143 sessId := userSession.Values[SessionId].(string)
144
145 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
146 if err != nil {
147 return nil, fmt.Errorf("failed to resume session: %w", err)
148 }
149
150 return clientSess, nil
151}
152
153func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
154 userSession, err := o.SessStore.Get(r, SessionName)
155 if err != nil {
156 return fmt.Errorf("error getting user session: %w", err)
157 }
158 if userSession.IsNew {
159 return fmt.Errorf("no session available for user")
160 }
161
162 d := userSession.Values[SessionDid].(string)
163 sessDid, err := syntax.ParseDID(d)
164 if err != nil {
165 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
166 }
167
168 sessId := userSession.Values[SessionId].(string)
169
170 // delete the session
171 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
172 if err1 != nil {
173 err1 = fmt.Errorf("failed to logout: %w", err1)
174 }
175
176 // remove the cookie
177 userSession.Options.MaxAge = -1
178 err2 := o.SessStore.Save(r, w, userSession)
179 if err2 != nil {
180 err2 = fmt.Errorf("failed to save into session store: %w", err2)
181 }
182
183 return errors.Join(err1, err2)
184}
185
186func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
187 registry := o.GetAccounts(r)
188 account := registry.FindAccount(targetDid)
189 if account == nil {
190 return fmt.Errorf("account not found in registry: %s", targetDid)
191 }
192
193 did, err := syntax.ParseDID(targetDid)
194 if err != nil {
195 return fmt.Errorf("invalid DID: %w", err)
196 }
197
198 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId)
199 if err != nil {
200 registry.RemoveAccount(targetDid)
201 _ = o.SaveAccounts(w, r, registry)
202 return fmt.Errorf("session expired for account: %w", err)
203 }
204
205 userSession, err := o.SessStore.Get(r, SessionName)
206 if err != nil {
207 return err
208 }
209
210 userSession.Values[SessionDid] = sess.Data.AccountDID.String()
211 userSession.Values[SessionPds] = sess.Data.HostURL
212 userSession.Values[SessionId] = sess.Data.SessionID
213 userSession.Values[SessionAuthenticated] = true
214
215 return userSession.Save(r, w)
216}
217
218func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
219 registry := o.GetAccounts(r)
220 account := registry.FindAccount(targetDid)
221 if account == nil {
222 return nil
223 }
224
225 did, err := syntax.ParseDID(targetDid)
226 if err == nil {
227 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId)
228 }
229
230 registry.RemoveAccount(targetDid)
231 return o.SaveAccounts(w, r, registry)
232}
233
// User identifies a signed-in account.
type User struct {
	Did string // account DID
	Pds string // PDS host URL
}
238
239func (o *OAuth) GetUser(r *http.Request) *User {
240 sess, err := o.ResumeSession(r)
241 if err != nil {
242 return nil
243 }
244
245 return &User{
246 Did: sess.Data.AccountDID.String(),
247 Pds: sess.Data.HostURL,
248 }
249}
250
251func (o *OAuth) GetDid(r *http.Request) string {
252 if u := o.GetMultiAccountUser(r); u != nil {
253 return u.Did()
254 }
255
256 return ""
257}
258
259func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
260 session, err := o.ResumeSession(r)
261 if err != nil {
262 return nil, fmt.Errorf("error getting session: %w", err)
263 }
264 return session.APIClient(), nil
265}
266
// ServiceClientOpts configures ServiceClient, a higher-level wrapper around
// ServerGetServiceAuth. Start from DefaultServiceClientOpts and apply
// ServiceClientOpt functions.
type ServiceClientOpts struct {
	service string        // target host; used for both the did:web audience and the URL host
	exp     int64         // absolute token expiry, unix seconds
	lxm     string        // passed as the lxm argument of ServerGetServiceAuth
	dev     bool          // when true, Host uses http:// instead of https://
	timeout time.Duration // timeout for the HTTP client ServiceClient builds
}

// ServiceClientOpt mutates a ServiceClientOpts in place.
type ServiceClientOpt func(*ServiceClientOpts)

// DefaultServiceClientOpts returns options with a 5 second timeout and every
// other field at its zero value.
func DefaultServiceClientOpts() ServiceClientOpts {
	var opts ServiceClientOpts
	opts.timeout = 5 * time.Second
	return opts
}

// WithService sets the target service host.
func WithService(service string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.service = service }
}

// WithExp sets the token to expire exp seconds after the option is applied.
//
// The absolute expiry stored is time.Now().Unix() + exp.
func WithExp(exp int64) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		opts.exp = exp + time.Now().Unix()
	}
}

// WithLxm sets the lxm value requested for the token.
func WithLxm(lxm string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.lxm = lxm }
}

// WithDev toggles plain-http hosts.
func WithDev(dev bool) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.dev = dev }
}

// WithTimeout sets the HTTP client timeout.
func WithTimeout(timeout time.Duration) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.timeout = timeout }
}

// Audience returns the did:web identifier of the configured service.
func (s *ServiceClientOpts) Audience() string {
	return "did:web:" + s.service
}

// Host returns the base URL of the configured service, using http:// in dev
// mode and https:// otherwise.
func (s *ServiceClientOpts) Host() string {
	if s.dev {
		return "http://" + s.service
	}
	return "https://" + s.service
}
329
330func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
331 opts := DefaultServiceClientOpts()
332 for _, o := range os {
333 o(&opts)
334 }
335
336 client, err := o.AuthorizedClient(r)
337 if err != nil {
338 return nil, err
339 }
340
341 // force expiry to atleast 60 seconds in the future
342 sixty := time.Now().Unix() + 60
343 if opts.exp < sixty {
344 opts.exp = sixty
345 }
346
347 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
348 if err != nil {
349 return nil, err
350 }
351
352 return &xrpc.Client{
353 Auth: &xrpc.AuthInfo{
354 AccessJwt: resp.Token,
355 },
356 Host: opts.Host(),
357 Client: &http.Client{
358 Timeout: opts.timeout,
359 },
360 }, nil
361}