// Monorepo for Tangled

1package oauth
2
3import (
4 "errors"
5 "fmt"
6 "log/slog"
7 "net/http"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/auth/oauth"
12 atpclient "github.com/bluesky-social/indigo/atproto/client"
13 atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 xrpc "github.com/bluesky-social/indigo/xrpc"
16 "github.com/gorilla/sessions"
17 "github.com/posthog/posthog-go"
18 "tangled.org/core/appview/config"
19 "tangled.org/core/appview/db"
20 "tangled.org/core/idresolver"
21 "tangled.org/core/rbac"
22)
23
// OAuth bundles the AT Protocol OAuth client application, the cookie-backed
// browser session store, and the application services handed to New.
type OAuth struct {
	// ClientApp is the OAuth client application; New constructs it on top of
	// a Redis-backed auth store.
	ClientApp *oauth.ClientApp
	// SessStore is the cookie store New creates from Config.Core.CookieSecret.
	SessStore *sessions.CookieStore
	// Config is the application configuration passed to New.
	Config *config.Config
	// JwksUri is ClientUri with "/oauth/jwks.json" appended.
	JwksUri string
	// ClientName is taken from Config.Core.AppviewName.
	ClientName string
	// ClientUri is "http://127.0.0.1:3000" in dev mode and
	// Config.Core.AppviewHost otherwise.
	ClientUri string
	// Posthog is the PostHog client passed to New.
	Posthog posthog.Client
	// Db is the database handle passed to New.
	Db *db.DB
	// Enforcer is the RBAC enforcer passed to New.
	Enforcer *rbac.Enforcer
	// IdResolver resolves identities; its Directory is also installed on
	// ClientApp by New.
	IdResolver *idresolver.Resolver
	// Logger is the structured logger passed to New.
	Logger *slog.Logger
}
37
38func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
39 var oauthConfig oauth.ClientConfig
40 var clientUri string
41 if config.Core.Dev {
42 clientUri = "http://127.0.0.1:3000"
43 callbackUri := clientUri + "/oauth/callback"
44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"})
45 } else {
46 clientUri = config.Core.AppviewHost
47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
48 callbackUri := clientUri + "/oauth/callback"
49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"})
50 }
51
52 // configure client secret
53 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
54 if err != nil {
55 return nil, err
56 }
57 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
58 return nil, err
59 }
60
61 jwksUri := clientUri + "/oauth/jwks.json"
62
63 authStore, err := NewRedisStore(&RedisStoreConfig{
64 RedisURL: config.Redis.ToURL(),
65 SessionExpiryDuration: time.Hour * 24 * 90,
66 SessionInactivityDuration: time.Hour * 24 * 14,
67 AuthRequestExpiryDuration: time.Minute * 30,
68 })
69 if err != nil {
70 return nil, err
71 }
72
73 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
74
75 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
76 clientApp.Dir = res.Directory()
77 // allow non-public transports in dev mode
78 if config.Core.Dev {
79 clientApp.Resolver.Client.Transport = http.DefaultTransport
80 }
81
82 clientName := config.Core.AppviewName
83
84 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential())
85 return &OAuth{
86 ClientApp: clientApp,
87 Config: config,
88 SessStore: sessStore,
89 JwksUri: jwksUri,
90 ClientName: clientName,
91 ClientUri: clientUri,
92 Posthog: ph,
93 Db: db,
94 Enforcer: enforcer,
95 IdResolver: res,
96 Logger: logger,
97 }, nil
98}
99
100func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
101 userSession, err := o.SessStore.Get(r, SessionName)
102 if err != nil {
103 return err
104 }
105
106 userSession.Values[SessionDid] = sessData.AccountDID.String()
107 userSession.Values[SessionPds] = sessData.HostURL
108 userSession.Values[SessionId] = sessData.SessionID
109 userSession.Values[SessionAuthenticated] = true
110
111 if err := userSession.Save(r, w); err != nil {
112 return err
113 }
114
115 handle := ""
116 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String())
117 if err == nil && resolved.Handle.String() != "" {
118 handle = resolved.Handle.String()
119 }
120
121 registry := o.GetAccounts(r)
122 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil {
123 return err
124 }
125 return o.SaveAccounts(w, r, registry)
126}
127
128func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
129 userSession, err := o.SessStore.Get(r, SessionName)
130 if err != nil {
131 return nil, fmt.Errorf("error getting user session: %w", err)
132 }
133 if userSession.IsNew {
134 return nil, fmt.Errorf("no session available for user")
135 }
136
137 d := userSession.Values[SessionDid].(string)
138 sessDid, err := syntax.ParseDID(d)
139 if err != nil {
140 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
141 }
142
143 sessId := userSession.Values[SessionId].(string)
144
145 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
146 if err != nil {
147 return nil, fmt.Errorf("failed to resume session: %w", err)
148 }
149
150 return clientSess, nil
151}
152
153func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
154 userSession, err := o.SessStore.Get(r, SessionName)
155 if err != nil {
156 return fmt.Errorf("error getting user session: %w", err)
157 }
158 if userSession.IsNew {
159 return fmt.Errorf("no session available for user")
160 }
161
162 d := userSession.Values[SessionDid].(string)
163 sessDid, err := syntax.ParseDID(d)
164 if err != nil {
165 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
166 }
167
168 sessId := userSession.Values[SessionId].(string)
169
170 // delete the session
171 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
172
173 // remove the cookie
174 userSession.Options.MaxAge = -1
175 err2 := o.SessStore.Save(r, w, userSession)
176
177 return errors.Join(err1, err2)
178}
179
180func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
181 registry := o.GetAccounts(r)
182 account := registry.FindAccount(targetDid)
183 if account == nil {
184 return fmt.Errorf("account not found in registry: %s", targetDid)
185 }
186
187 did, err := syntax.ParseDID(targetDid)
188 if err != nil {
189 return fmt.Errorf("invalid DID: %w", err)
190 }
191
192 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId)
193 if err != nil {
194 registry.RemoveAccount(targetDid)
195 _ = o.SaveAccounts(w, r, registry)
196 return fmt.Errorf("session expired for account: %w", err)
197 }
198
199 userSession, err := o.SessStore.Get(r, SessionName)
200 if err != nil {
201 return err
202 }
203
204 userSession.Values[SessionDid] = sess.Data.AccountDID.String()
205 userSession.Values[SessionPds] = sess.Data.HostURL
206 userSession.Values[SessionId] = sess.Data.SessionID
207 userSession.Values[SessionAuthenticated] = true
208
209 return userSession.Save(r, w)
210}
211
212func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
213 registry := o.GetAccounts(r)
214 account := registry.FindAccount(targetDid)
215 if account == nil {
216 return nil
217 }
218
219 did, err := syntax.ParseDID(targetDid)
220 if err == nil {
221 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId)
222 }
223
224 registry.RemoveAccount(targetDid)
225 return o.SaveAccounts(w, r, registry)
226}
227
// User identifies the account behind a resumed session, as built by GetUser.
type User struct {
	// Did is the account DID in string form.
	Did string
	// Pds is the host URL recorded in the session data.
	Pds string
}
232
233func (o *OAuth) GetUser(r *http.Request) *User {
234 sess, err := o.ResumeSession(r)
235 if err != nil {
236 return nil
237 }
238
239 return &User{
240 Did: sess.Data.AccountDID.String(),
241 Pds: sess.Data.HostURL,
242 }
243}
244
245func (o *OAuth) GetDid(r *http.Request) string {
246 if u := o.GetMultiAccountUser(r); u != nil {
247 return u.Did()
248 }
249
250 return ""
251}
252
253func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
254 session, err := o.ResumeSession(r)
255 if err != nil {
256 return nil, fmt.Errorf("error getting session: %w", err)
257 }
258 return session.APIClient(), nil
259}
260
261// this is a higher level abstraction on ServerGetServiceAuth
262type ServiceClientOpts struct {
263 service string
264 exp int64
265 lxm string
266 dev bool
267 timeout time.Duration
268}
269
270type ServiceClientOpt func(*ServiceClientOpts)
271
272func DefaultServiceClientOpts() ServiceClientOpts {
273 return ServiceClientOpts{
274 timeout: time.Second * 5,
275 }
276}
277
278func WithService(service string) ServiceClientOpt {
279 return func(s *ServiceClientOpts) {
280 s.service = service
281 }
282}
283
284// Specify the Duration in seconds for the expiry of this token
285//
286// The time of expiry is calculated as time.Now().Unix() + exp
287func WithExp(exp int64) ServiceClientOpt {
288 return func(s *ServiceClientOpts) {
289 s.exp = time.Now().Unix() + exp
290 }
291}
292
293func WithLxm(lxm string) ServiceClientOpt {
294 return func(s *ServiceClientOpts) {
295 s.lxm = lxm
296 }
297}
298
299func WithDev(dev bool) ServiceClientOpt {
300 return func(s *ServiceClientOpts) {
301 s.dev = dev
302 }
303}
304
305func WithTimeout(timeout time.Duration) ServiceClientOpt {
306 return func(s *ServiceClientOpts) {
307 s.timeout = timeout
308 }
309}
310
311func (s *ServiceClientOpts) Audience() string {
312 return fmt.Sprintf("did:web:%s", s.service)
313}
314
315func (s *ServiceClientOpts) Host() string {
316 scheme := "https://"
317 if s.dev {
318 scheme = "http://"
319 }
320
321 return scheme + s.service
322}
323
324func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
325 opts := DefaultServiceClientOpts()
326 for _, o := range os {
327 o(&opts)
328 }
329
330 client, err := o.AuthorizedClient(r)
331 if err != nil {
332 return nil, err
333 }
334
335 // force expiry to atleast 60 seconds in the future
336 sixty := time.Now().Unix() + 60
337 if opts.exp < sixty {
338 opts.exp = sixty
339 }
340
341 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
342 if err != nil {
343 return nil, err
344 }
345
346 return &xrpc.Client{
347 Auth: &xrpc.AuthInfo{
348 AccessJwt: resp.Token,
349 },
350 Host: opts.Host(),
351 Client: &http.Client{
352 Timeout: opts.timeout,
353 },
354 }, nil
355}